Add LightGBM classifier implementation #44
Conversation
- Implement `LightGBMClassifier` in `sml/ensemble/lightgbm.py`
  - Supports gradient boosting with decision trees
  - Parameters: `n_estimators`, `learning_rate`, `max_depth`, `num_leaves`, `criterion`, `epsilon`
  - Methods: `fit()`, `predict()`, `decision_function()`
  - Consistent with the AdaBoost and RandomForest implementations
- Add comprehensive tests in `tests/ensemble/lightgbm_test.py`
  - Tests against sklearn's `GradientBoostingClassifier`
  - Uses the Iris dataset with a boolean feature transformation
- Add emulation tests in `emulations/ensemble/lightgbm_emul.py`
  - Performance benchmarking against sklearn
  - SPU execution comparison
- Update the ensemble `__init__.py` to export `LightGBMClassifier`
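The tests compare the SPU implementation against sklearn's `GradientBoostingClassifier`. A minimal plaintext version of that baseline (standard sklearn API; the hyperparameter values below are illustrative, not necessarily the PR's exact settings) might look like:

```python
# Hedged sketch: the sklearn baseline the PR's tests compare against.
# Uses the Iris dataset, as the tests do.
from sklearn.datasets import load_iris
from sklearn.ensemble import GradientBoostingClassifier

X, y = load_iris(return_X_y=True)
clf = GradientBoostingClassifier(
    n_estimators=50, learning_rate=0.1, max_depth=3, random_state=0
)
clf.fit(X, y)
score_plain = clf.score(X, y)  # training accuracy of the plaintext baseline
```

The SPU-side accuracy is then compared against `score_plain` in both the unit tests and the emulation script.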
Summary of Changes (Gemini Code Assist): Hello @YuzeHao2023, I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers get up to speed: this pull request introduces a new LightGBM classifier implementation, expanding the library's machine learning capabilities.
Code Review
This pull request introduces a new LightGBMClassifier, along with corresponding tests and emulation scripts. My review focuses on the correctness of the implementation, its maintainability, and the quality of the tests. The main issue is that the classifier's name LightGBMClassifier is misleading, as the implementation appears to be a variant of AdaBoost (SAMME.R) rather than the LightGBM algorithm. I've also identified an unused parameter in the classifier's constructor and noted that the tests lack assertions to verify correctness. Additionally, there's some duplicated code between the test and emulation files that could be refactored. Please see my detailed comments for suggestions on how to address these points.
```python
from sml.tree.tree import DecisionTreeClassifier as sml_dtc


class LightGBMClassifier:
```
The class is named LightGBMClassifier, but the implementation follows the AdaBoost.SAMME.R algorithm, not LightGBM or even standard gradient boosting. The weight update rule in _boost_round (lines 220-226) is characteristic of AdaBoost. This name is misleading and can cause confusion for users. Please rename the class to something that accurately reflects the algorithm, such as SAMMEClassifier.
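For reference, the sample-weight update that marks an implementation as AdaBoost rather than gradient boosting has the following shape (discrete SAMME shown for concreteness; SAMME.R uses class-probability estimates but applies the same multiplicative reweighting — the function and variable names below are illustrative, not from the PR):

```python
import numpy as np

def samme_weight_update(w, y_true, y_pred, n_classes):
    """One discrete-SAMME boosting round: multiplicative reweighting
    of misclassified samples, the signature AdaBoost step."""
    incorrect = (y_pred != y_true).astype(float)
    err = np.sum(w * incorrect) / np.sum(w)
    # Estimator weight: log-odds of the weighted error plus the multiclass term.
    alpha = np.log((1.0 - err) / err) + np.log(n_classes - 1)
    # Upweight only the misclassified samples, then renormalize.
    w = w * np.exp(alpha * incorrect)
    return w / np.sum(w), alpha

w = np.full(4, 0.25)
y_true = np.array([0, 1, 2, 1])
y_pred = np.array([0, 1, 2, 0])  # one mistake
w_new, alpha = samme_weight_update(w, y_true, y_pred, n_classes=3)
```

Gradient boosting, by contrast, fits each new tree to the gradient of the loss at the current prediction rather than reweighting samples, which is why the reviewed `_boost_round` reads as AdaBoost.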
```python
num_leaves : int
    The maximum number of leaves in each tree.
    Must be greater than 1. Default is 31.
```
The num_leaves parameter is defined in the docstring and accepted in __init__, but it is not used anywhere in the implementation. The DecisionTreeClassifier is instantiated using max_depth only. This can be misleading for users of the class. Please remove the num_leaves parameter from the docstring, the __init__ method signature, its validation, and its assignment to self.num_leaves.
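A minimal sketch of the trimmed constructor (class name and defaults here are illustrative; only the parameters the implementation actually consumes are kept):

```python
class _BoostedClassifier:
    """Illustrative constructor with the unused `num_leaves` removed."""

    def __init__(self, n_estimators=100, learning_rate=0.1, max_depth=3,
                 criterion="gini", epsilon=1e-5):
        assert n_estimators > 0, "n_estimators must be positive"
        assert learning_rate > 0, "learning_rate must be positive"
        self.n_estimators = n_estimators
        self.learning_rate = learning_rate
        self.max_depth = max_depth
        self.criterion = criterion
        self.epsilon = epsilon

clf = _BoostedClassifier(n_estimators=10)
```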
```python
print(f"Accuracy in SKlearn: {score_plain}")
print(f"Accuracy in SPU: {score_encrypted}")
```
The test calculates accuracy scores but only prints them. A unit test must contain assertions to automatically verify the correctness of the implementation. Please add an assertion to check if the score from the SPU implementation is close to the score from the sklearn implementation. For example: assert jnp.isclose(score_plain, score_encrypted, atol=0.05).
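Concretely, the suggested check looks like this (shown with `numpy.isclose`, whose semantics `jnp.isclose` mirrors; the 0.05 tolerance and the score values are the reviewer's example and placeholders, respectively):

```python
import numpy as np

# Hypothetical scores standing in for the sklearn and SPU accuracies.
score_plain = 0.96
score_encrypted = 0.94

# Fails the test automatically if the SPU accuracy drifts from the baseline
# by more than the absolute tolerance.
assert np.isclose(score_plain, score_encrypted, atol=0.05)
```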
```python
def load_data():
    iris = load_iris()
    iris_data, iris_label = jnp.array(iris.data), jnp.array(iris.target)
    # sorted_features: n_samples * n_features_in
    n_samples, n_features_in = iris_data.shape
    sorted_features = jnp.sort(iris_data, axis=0)
    new_threshold = (sorted_features[:-1, :] + sorted_features[1:, :]) / 2
    new_features = jnp.greater_equal(
        iris_data[:, :], new_threshold[:, jnp.newaxis, :]
    )
    new_features = new_features.transpose([1, 0, 2]).reshape(n_samples, -1)

    X, y = new_features[:, ::3], iris_label[:]
    return X, y
```
```python
# limitations under the License.

import copy
```
```python
pred = sum(
    jnp.where(
        (estimator.predict(X) == classes).T,
        w,
        -1 / (n_classes - 1) * w,
    )
    * flag
    for estimator, w, flag in zip(
        self.estimators_,
        self.estimator_weight_,
        self.estimator_flags_,
        strict=True,
    )
)
```
The decision_function uses a Python sum() on a generator expression. This pattern is not compatible with JAX's JIT compilation and can be inefficient for a large number of estimators due to Python-level iteration. For better performance and to enable JIT compilation, consider refactoring this to use jax.lax.fori_loop to accumulate the predictions.
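One way to make the accumulation JIT-friendly, assuming the per-estimator contributions can first be stacked into a single array (shapes and names below are illustrative, not the PR's):

```python
import jax
import jax.numpy as jnp

def decision_function_fori(stacked_terms):
    """Accumulate per-estimator terms with jax.lax.fori_loop.

    stacked_terms: (n_estimators, n_samples, n_classes) array holding each
    estimator's weighted, flagged contribution.
    """
    n_estimators = stacked_terms.shape[0]

    def body(i, acc):
        # Add the i-th estimator's contribution to the running total.
        return acc + stacked_terms[i]

    init = jnp.zeros_like(stacked_terms[0])
    return jax.lax.fori_loop(0, n_estimators, body, init)

terms = jnp.arange(24.0).reshape(4, 3, 2)
out = decision_function_fori(terms)
```

Since `fori_loop` traces the body once, the compiled graph no longer grows with `n_estimators`; when all terms fit in memory, a plain `stacked_terms.sum(axis=0)` is an even simpler JIT-compatible alternative.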
```python
def load_data():
    iris = load_iris()
    iris_data, iris_label = jnp.array(iris.data), jnp.array(iris.target)
    # sorted_features: n_samples * n_features_in
    n_samples, n_features_in = iris_data.shape
    sorted_features = jnp.sort(iris_data, axis=0)
    new_threshold = (sorted_features[:-1, :] + sorted_features[1:, :]) / 2
    new_features = jnp.greater_equal(
        iris_data[:, :], new_threshold[:, jnp.newaxis, :]
    )
    new_features = new_features.transpose([1, 0, 2]).reshape(n_samples, -1)

    X, y = new_features[:, ::3], iris_label[:]
    return X, y
```
```python
sim = spsim.Simulator.simple(3, libspu.ProtocolKind.ABY3, libspu.FieldType.FM64)

X, y = load_data()
n_samples, n_features = X.shape
```