Comments (11)
Thank you for pointing this out!
I spent a while digging into it, and you are correct. I think the intention when we did the refactoring in #2777 was to pass the input to the layer, like you pointed out. That is especially useful for some layers (`MeanPooling` and `MaxPooling`, for example). However, I see two errors: `FFN::Backward()` is passing the layer outputs to `network->Backward()` (which is `MultiLayer::Backward()`), and then all the internal calls there are off by one in the layer outputs they pass.
It is possible to rederive `LogSoftmax` and `Softmax` such that they use the input values, not the output; but we have both available in `MultiLayer` and `FFN`. I wonder if it is better to simply change the API of `Backward()`:
```c++
void Backward(const MatType& forwardInput,
              const MatType& forwardOutput,
              const MatType& gy,
              MatType& delta);
```
I'd be curious what @zoq and @Aakash-kaushik and others think before we move forward with that. It wouldn't be a high-effort refactoring, and I don't see any downsides (except the marginally increased complexity of another parameter), but maybe I overlooked something. Also related: #3470 and @mrdaybird's related observations there.
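For illustration, here is a minimal sketch of why having both arguments helps. The layer functions and the `std::vector` stand-in for `MatType` are hypothetical toys, not mlpack's actual implementations: a sigmoid-style layer derives most cheaply from the output, while a softplus-style layer needs the input.

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Toy stand-in for MatType; mlpack would use an Armadillo matrix.
using Mat = std::vector<double>;

// Hypothetical sigmoid layer: its derivative y * (1 - y) is cheapest to
// express with the *output* of the forward pass.
void SigmoidBackward(const Mat& /* forwardInput */,
                     const Mat& forwardOutput,
                     const Mat& gy,
                     Mat& delta)
{
  delta.resize(gy.size());
  for (size_t i = 0; i < gy.size(); ++i)
    delta[i] = gy[i] * forwardOutput[i] * (1.0 - forwardOutput[i]);
}

// Hypothetical softplus layer: its derivative 1 / (1 + exp(-x)) needs the
// *input*; recovering x from the output would cost extra log/exp calls.
void SoftplusBackward(const Mat& forwardInput,
                      const Mat& /* forwardOutput */,
                      const Mat& gy,
                      Mat& delta)
{
  delta.resize(gy.size());
  for (size_t i = 0; i < gy.size(); ++i)
    delta[i] = gy[i] / (1.0 + std::exp(-forwardInput[i]));
}
```

With the proposed four-argument signature, each layer simply ignores the argument it doesn't need.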
from mlpack.
I guess it would be simpler to do it all together, since I've already started down that path. I just wasn't sure if you prefer having separate changes like this merged separately.
This would certainly be the most flexible solution, as some layers need the input and some benefit (in efficiency) from having the output. At the very least, I think it should be changed to the input. An alternative (though perhaps less memory-efficient), already used in certain places, is for any layer that needs its output in the `Backward()`/`Gradient()` methods to simply remember it during the forward pass. The API already has the contract that `Backward()` is not called until `Forward()` has been called.
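The caching alternative could look something like the following sketch (a hypothetical toy layer, not mlpack's actual code), relying on the existing Forward-before-Backward contract:

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Toy stand-in for MatType.
using Mat = std::vector<double>;

// Hypothetical layer that caches its output in Forward() and reuses it in
// Backward(); valid only because Forward() is guaranteed to run first.
class CachedSigmoid
{
 public:
  void Forward(const Mat& input, Mat& output)
  {
    output.resize(input.size());
    for (size_t i = 0; i < input.size(); ++i)
      output[i] = 1.0 / (1.0 + std::exp(-input[i]));
    cachedOutput = output; // the extra copy is the memory cost in question
  }

  void Backward(const Mat& gy, Mat& delta) const
  {
    delta.resize(gy.size());
    for (size_t i = 0; i < gy.size(); ++i)
      delta[i] = gy[i] * cachedOutput[i] * (1.0 - cachedOutput[i]);
  }

 private:
  Mat cachedOutput;
};
```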
I am happy to do the refactoring if you want. I have already done some of it in the branch mentioned above (when I started, I didn't realize the extent of the issue). I can pull the relevant changes over to a new branch explicitly for this issue.
I chatted with @zoq, and he agrees that adding an extra parameter is a reasonable thing to do. If you're up for the refactoring, I would really appreciate that, and I can review once it's ready. It's up to you whether you want to do both as part of #3550 or make a new PR; the changes are orthogonal enough that while reviewing I should be able to keep them separated in my head. Thanks again for pointing this out. 👍
I just noticed that in the comment for `Backward()` on `MultiLayer`, it seems that `input` was intended to be the output, at least according to the documentation:
```c++
 * Perform a backward pass with the given data. `gy` is expected to be the
 * propagated error from the subsequent layer (or output), `input` is expected
 * to be the output from this layer when `Forward()` was called, and `g` will
 * store the propagated error from this layer (to be passed to the previous
 * layer as `gy`).
```
Committed (most of) the changes needed to the branch for #3547. The one remaining bit (that I know of) is that I believe most of the activation functions are written assuming that `Deriv()` is passed the output from the layer, but some (e.g. `InvQuadFunction`) seem to be written assuming the input will be passed to `Deriv()`. I can try to go through and identify which is which, but then what do we do about it?
My thought was to perhaps have a `FunctionTraits<>` template class which indicates whether `Deriv()` expects the output or the input. That way, whichever one is more efficient to implement could be used. `BaseLayer` could have two implementations of `Backward()` based on the indicator. Either that, or have two differently named methods on the activation functions, depending on which one is expected.
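A minimal sketch of the trait idea (all names here are hypothetical, not mlpack's actual API) could look like this, with the dispatch resolved at compile time:

```cpp
#include <cassert>
#include <cmath>

// Hypothetical trait: each activation function declares which value its
// Deriv() expects. Defaults to "derive from output".
template<typename ActivationFunction>
struct FunctionTraits
{
  static constexpr bool DerivUsesOutput = true;
};

struct SigmoidFn
{
  static double Fn(double x) { return 1.0 / (1.0 + std::exp(-x)); }
  static double Deriv(double y) { return y * (1.0 - y); } // from output
};

struct SoftplusFn
{
  static double Fn(double x) { return std::log1p(std::exp(x)); }
  static double Deriv(double x) { return 1.0 / (1.0 + std::exp(-x)); } // from input
};

// Softplus opts out of the default: its Deriv() wants the input.
template<>
struct FunctionTraits<SoftplusFn>
{
  static constexpr bool DerivUsesOutput = false;
};

// A BaseLayer-style Backward() can then branch on the trait:
template<typename ActivationFunction>
double Backward(double input, double output, double gy)
{
  if constexpr (FunctionTraits<ActivationFunction>::DerivUsesOutput)
    return gy * ActivationFunction::Deriv(output);
  else
    return gy * ActivationFunction::Deriv(input);
}
```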
Yeah, sometimes the derivative is actually easier to express in terms of the input. If `Backward()` already needs both the inputs and the outputs available, I feel like there's a good argument for making exactly the same change to `Deriv()` and supplying both the input and output. I took a quick look through and saw that `Softplus` also wants the input, not the output.
The `FunctionTraits` idea could work, although it wouldn't cover the case where it happens to be most convenient to express the derivative in terms of both the output and the input. (For example, take the random function I just made up, y = exp(x^2), whose derivative dy/dx = 2x * y uses both the input x and the output y.) But I'm happy to put off considering that situation until it's actually encountered.
@mrdaybird's work in #3478 (comment) is directly relevant for this second issue (it is exactly the same issue). In fact, the solution proposed there is basically equivalent to using `FunctionTraits`, although in this case, instead of two different classes, we can keep the code all in `BaseLayer` and just branch according to the compile-time property of whether we want the input or the output.
Yeah, I guess the cleanest/most flexible solution might be to pass both the input and the output to `Deriv()`, although that feels a little strange for some reason. ;) If this were a general function abstraction, I'm not sure I'd like that, but considering it is only used here, it's probably okay.
If we went with deriving from either the input or the output, I guess it could also be something like: you implement exactly one of `DerivFromX` or `DerivFromY`.
I'm happy to change it to whichever you think makes the most sense here.
What if I make activation functions support one (and only one) of the following:

- `DerivFromInput`
- `DerivFromOutput`
- `DerivFromBoth`

The implementer could choose to implement whichever one makes the most sense, and the name makes it clear what you are implementing. `BaseLayer` could then have three versions of `Backward()`, depending on which one of these is implemented.
Or... I could just make it take both all the time.
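The named-method idea could be sketched roughly as follows (again with hypothetical names and a detection idiom standing in for whatever mlpack would actually use), where the dispatcher picks whichever method the function provides:

```cpp
#include <cassert>
#include <cmath>
#include <type_traits>

// Detect at compile time whether an activation function provides
// DerivFromOutput (names here are illustrative, not mlpack's API).
template<typename F, typename = void>
struct HasDerivFromOutput : std::false_type { };

template<typename F>
struct HasDerivFromOutput<F,
    std::void_t<decltype(F::DerivFromOutput(0.0))>> : std::true_type { };

struct SigmoidFn
{
  // Sigmoid's derivative is cheapest from the output.
  static double DerivFromOutput(double y) { return y * (1.0 - y); }
};

struct SoftplusFn
{
  // Softplus's derivative is cheapest from the input.
  static double DerivFromInput(double x) { return 1.0 / (1.0 + std::exp(-x)); }
};

// A BaseLayer-style dispatcher uses whichever method is available.
template<typename F>
double Deriv(double input, double output)
{
  if constexpr (HasDerivFromOutput<F>::value)
    return F::DerivFromOutput(output);
  else
    return F::DerivFromInput(input);
}
```

A `DerivFromBoth` overload could be detected and preferred the same way.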
Up to you, I don't have a particularly strong opinion here. It's probably a nice thing to do to rename the function to indicate whether it's input or output. 👍
Fixed in #3547. 👍