Comments (2)
Here is my implementation of complex dropout. Can you check whether it is correct?
import numpy as np
import torch
import torch.nn as nn

def complex_dropout(input_r, input_i, p=0.5, training=True, inplace=False):
    if not training:
        return input_r, input_i
    # Draw one Bernoulli keep-mask (1 = keep, with probability 1-p) shared by both parts.
    bernoulli_dist = torch.from_numpy(np.random.binomial(1, 1-p, input_r.shape))
    d_r, d_i = input_r.masked_fill(bernoulli_dist == 0, 0.), input_i.masked_fill(bernoulli_dist == 0, 0.)
    # TODO: implement the in-place variant
    return d_r, d_i

class ComplexDropout(nn.Module):
    def __init__(self, p=0.5, inplace=False):
        super(ComplexDropout, self).__init__()
        self.p = p
        self.inplace = inplace

    def forward(self, input_r, input_i):
        return complex_dropout(input_r, input_i, self.p, self.training, self.inplace)
Indeed, the dropout was not correct as the real and imaginary parts did not drop the same elements. I also had a related issue with max_pool.
I rewrote the entire thing to use the new complex tensors and corrected the dropout.
My solution is not very elegant: I apply dropout to a tensor filled with ones and use the result as a mask, which I multiply with the complex tensor.
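import torch
from torch.nn.functional import dropout

def complex_dropout(input, p=0.5, training=True):
    # The real and imaginary parts need the same dropout mask;
    # this is not a clean solution!
    # Build a real-valued mask of ones with the input's shape and device.
    mask = torch.ones_like(input, dtype=torch.float32)
    # dropout() zeroes entries and already rescales the kept ones by 1/(1-p)
    # in training mode, so no extra scaling factor is applied here.
    mask = dropout(mask, p, training)
    return mask * input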
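For anyone who wants to check the shared-mask behaviour, here is a minimal self-contained sketch. The name `complex_dropout_sketch` is hypothetical and not part of the library. It assumes a PyTorch version with native complex tensors, and it relies on `torch.nn.functional.dropout` already scaling the kept entries by 1/(1-p) in training mode, so it applies no extra factor.

```python
import torch
import torch.nn.functional as F


def complex_dropout_sketch(x, p=0.5, training=True):
    """Hypothetical helper: drop the same positions in the real and imaginary parts."""
    if not training or p == 0.0:
        return x
    # Real-valued mask of ones with x's shape and device (avoids a complex-to-real cast).
    mask = torch.ones_like(x, dtype=torch.float32)
    # F.dropout zeroes entries with probability p and scales the survivors
    # by 1/(1-p), so every mask value is either 0 or 1/(1-p).
    mask = F.dropout(mask, p, training=True)
    return x * mask


torch.manual_seed(0)
x = torch.randn(4, 8, dtype=torch.complex64)
y = complex_dropout_sketch(x, p=0.5)

# Positions zeroed in the real part are exactly those zeroed in the imaginary part.
same_positions = torch.equal(y.real == 0, y.imag == 0)
print("shared mask:", same_positions)
```

Because the mask is created with `torch.ones_like`, it lives on the same device as the input, and in evaluation mode the input is returned unchanged.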