Comments (3)
I think the issue is caused by a bug on line 239 of forestci.py. That line should be replaced by:
pred_mean = np.mean(pred, 1)
pred_centered = (pred.T - pred_mean).T
because we want to average over the bootstrap dimension (the trees), not the test dimension.
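A minimal sketch of the proposed fix, assuming 'pred' has shape (n_test, n_trees), i.e. one row per test sample and one column per tree. The random array is a placeholder for real tree predictions and the sizes are arbitrary.

```python
import numpy as np

# Placeholder per-tree predictions: rows are test samples,
# columns are trees (bootstrap replicates). Sizes are arbitrary.
rng = np.random.default_rng(0)
n_test, n_trees = 4, 50
pred = rng.normal(size=(n_test, n_trees))

# Average over the bootstrap (tree) dimension, axis 1 -> shape (n_test,)
pred_mean = np.mean(pred, 1)

# Transpose so the trailing axis has length n_test, subtract the
# per-sample mean via broadcasting, then transpose back.
pred_centered = (pred.T - pred_mean).T

print(pred_centered.shape)  # (4, 50)
# Each test sample's centered predictions now average to (numerically) zero.
print(np.allclose(pred_centered.mean(axis=1), 0))  # True
```

The double transpose is needed because NumPy broadcasts from the trailing axis: a length-n_test vector lines up with the last axis of pred.T, not of pred.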
from forest-confidence-interval.
@agrawalraj I also ran into this issue some time ago.
The solution I found back then is similar to yours.
While debugging the function you copied those lines from, I found that 'pred' has the following dimensions:
axis 0: the test samples
axis 1: the prediction of each tree
It is the result of this line:
pred = np.array([tree.predict(X_test) for tree in forest]).T
For a single test sample, 'pred' has shape (1, n_trees). If the mean is taken over axis 0 (the samples), it simply returns each tree's own prediction, so every centered prediction is zero instead of reflecting the spread across trees.
Either way, it does not make sense to average the predictions of different samples for the same tree; what we want is the average of the predictions of all trees in the forest for each sample.
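To illustrate the single-sample case (as in LOOCV), the sketch below compares averaging over axis 0 (across samples, which this comment argues against) with averaging over axis 1 (across trees). The predictions are random placeholders, not output from a real forest.

```python
import numpy as np

rng = np.random.default_rng(1)
n_trees = 50
# A single test sample: pred has shape (1, n_trees).
pred = rng.normal(size=(1, n_trees))

# Averaging over axis 0 averages the samples for each tree. With one
# sample this just returns each tree's own prediction.
wrong_mean = np.mean(pred, 0)
wrong_centered = pred - wrong_mean
print(np.count_nonzero(wrong_centered))  # 0: the spread across trees is lost

# Averaging over axis 1 averages all trees for the sample.
right_mean = np.mean(pred, 1).reshape(pred.shape[0], 1)
right_centered = pred - right_mean
print(np.count_nonzero(right_centered) > 0)  # True
```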
This fixed it for me:
pred_mean = np.mean(pred, 1).reshape(X_test.shape[0], 1)
Nothing else had to be changed.
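The reshape fix and the transpose-based fix from the first comment should produce the same centered array. The sketch below checks this on placeholder data; X_test is a dummy array used only for its row count, and pred.mean(axis=1, keepdims=True) is shown as a standard NumPy alternative to the explicit reshape.

```python
import numpy as np

rng = np.random.default_rng(2)
X_test = rng.normal(size=(3, 5))                # dummy features; only shape[0] is used
pred = rng.normal(size=(X_test.shape[0], 20))   # (n_test, n_trees) placeholder

# Fix from this comment: column vector of per-sample means, shape (n_test, 1)
pred_mean = np.mean(pred, 1).reshape(X_test.shape[0], 1)
centered_reshape = pred - pred_mean

# Fix from the first comment: transpose, subtract, transpose back
centered_transpose = (pred.T - np.mean(pred, 1)).T

# Equivalent: keepdims=True keeps the reduced axis with length 1
centered_keepdims = pred - pred.mean(axis=1, keepdims=True)

print(np.allclose(centered_reshape, centered_transpose))  # True
print(np.allclose(centered_reshape, centered_keepdims))   # True
```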
It might be beneficial to add the single (test) sample case, as in LOOCV, to the test suite.
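A sketch of what such a test could look like. It is written against a hypothetical helper, center_predictions, that stands in for the centering step; it is not part of forestci's API. The tests check that a one-row input is not collapsed to zeros and that centering one sample at a time matches centering the whole batch.

```python
import numpy as np


def center_predictions(pred):
    """Hypothetical helper: center per-tree predictions for each test sample.

    pred has shape (n_test, n_trees); the mean is taken over the trees.
    """
    pred = np.asarray(pred, dtype=float)
    return pred - pred.mean(axis=1, keepdims=True)


def test_single_sample_matches_batch():
    rng = np.random.default_rng(3)
    pred = rng.normal(size=(5, 30))
    batch = center_predictions(pred)
    # Centering each row on its own (as in LOOCV) must give the same
    # result as centering the whole batch.
    for i in range(pred.shape[0]):
        single = center_predictions(pred[i:i + 1])
        assert single.shape == (1, 30)
        assert np.allclose(single[0], batch[i])


def test_single_sample_keeps_tree_spread():
    pred = np.array([[1.0, 2.0, 3.0, 6.0]])  # one sample, four trees
    centered = center_predictions(pred)
    # The mean is 3.0, so the centered values are [-2, -1, 0, 3], not all zeros
    assert np.allclose(centered, [[-2.0, -1.0, 0.0, 3.0]])


if __name__ == "__main__":
    test_single_sample_matches_batch()
    test_single_sample_keeps_tree_spread()
    print("ok")
```

Both functions follow the pytest naming convention, so they would also be collected by pytest if placed in a test module.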
We'd welcome a pull request.
Related Issues (20)
- cannot import name '_get_n_samples_bootstrap'
- ValueError on multiple output problems
- Sum taken over wrong axis
- Can this package be adapted to perform Thompson sampling?
- Compatibility issues with scikit-learn 0.24.2
- Benchmarking confidence intervals
- Array dimensions incorrect for confidence intervals
- New Release
- progress indicator?
- random_forest_error() does not work without scalers.
- Unnecessary usage of training data?
- Warning: sklearn.ensemble.forest module is deprecated in version 0.22
- Can't uninstall forestci
- Overflow errors
- Not compatible with SKLearn version 0.22.1
- amount of trees needed to work
- All 0's in `g_eta_raw`.
- forest error are all NaN
- Applicability to non-binary classification tasks
- Allow general bagging estimators