Improve latency of a RandomForestRegressor.predict #16310
Conversation
… and check_input are disabled
```diff
@@ -743,7 +743,7 @@ def __init__(self,
                  warm_start=warm_start,
                  max_samples=max_samples)

-    def predict(self, X):
+    def predict(self, X, check_input=True, parallel_predict=True):
```
I would do this with context managers rather than adding public parameters, e.g. using:

```python
with parallel_backend('threading', n_jobs=1):
    clf.predict(X)
```

and we should be able to do:

```python
with sklearn.config_context(check_input=False):
    ....
```
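The first suggestion already works with released joblib and scikit-learn; a minimal runnable sketch (dataset, model size, and variable names are illustrative, and note that `config_context(check_input=False)` above is only a proposal, not an existing option):

```python
from joblib import parallel_backend
from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor

X, y = make_regression(n_samples=200, n_features=10, random_state=0)
clf = RandomForestRegressor(n_estimators=100, random_state=0).fit(X, y)

# Forcing a sequential threading backend avoids joblib's dispatch
# overhead, which dominates when predicting a single observation.
with parallel_backend('threading', n_jobs=1):
    pred = clf.predict(X[:1])
print(pred.shape)
```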
DecisionTreeRegressor.predict still has a check_input parameter (https://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeRegressor.html#sklearn.tree.DecisionTreeRegressor.predict). Which one should take precedence when the two check_input values disagree?
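For reference, the per-tree check_input=False path can already be exercised today by averaging the individual trees' predictions, provided the caller prepares the float32 C-contiguous input the trees expect; a sketch (variable names illustrative):

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor

X, y = make_regression(n_samples=100, n_features=10, random_state=0)
reg = RandomForestRegressor(n_estimators=10, random_state=0).fit(X, y)

# check_input=False skips per-tree validation, so the caller must
# supply the dtype/layout the trees expect internally.
X32 = np.ascontiguousarray(X[:1], dtype=np.float32)
manual = np.mean(
    [t.predict(X32, check_input=False) for t in reg.estimators_], axis=0
)
print(np.allclose(manual, reg.predict(X[:1])))
```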
I extended the script to compare the predict method against two ONNX runtimes (onnxruntime, mlprodict), which work well on small batches:
When the batch size is 1000 (instead of 1 above), the results are comparable:
When the batch size is 100,000, onnxruntime is the slowest; the other runtime is still fast.
Skipping the validation checks on specific estimators will be significantly easier once the PR on …
On master, the default is …
Not sure what the status of this PR is. I suggest closing it.
Let's close (but keep #16143 open). We need to redo profiling on the main branch. I suspect that even when …
Reference Issues/PRs
Proposes one fix for #16143.
What does this implement/fix? Explain your changes.
Allow the user to disable the parallelisation or to set check_input=False when calling predict on a RandomForestRegressor. The prediction time is divided by 3 when parallelisation is disabled and check_input=False, with 100 trees, 10 features, and 1 observation.
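The single-observation latency can be spot-checked with a small timing harness; this sketch only exercises the sequential-backend side (the measured numbers are machine-dependent, and the helper name and model sizes are illustrative):

```python
import time
from joblib import parallel_backend
from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor

# Setup mirrors the PR description: 100 trees, 10 features, 1 observation.
X, y = make_regression(n_samples=500, n_features=10, random_state=0)
reg = RandomForestRegressor(n_estimators=100, random_state=0).fit(X, y)
x = X[:1]

def mean_seconds(f, n=50):
    # Average wall-clock time of n calls to f.
    t0 = time.perf_counter()
    for _ in range(n):
        f()
    return (time.perf_counter() - t0) / n

t_default = mean_seconds(lambda: reg.predict(x))
with parallel_backend('threading', n_jobs=1):
    t_seq = mean_seconds(lambda: reg.predict(x))
print(f"default: {t_default * 1e3:.2f} ms, sequential: {t_seq * 1e3:.2f} ms")
```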
Profile with py-spy: