Comments (6)

MhLiao commented on June 22, 2024

We modified the CRNN code to output the probability.
For details, see the CTC paper: http://www.machinelearning.org/proceedings/icml2006/047_Connectionist_Tempor.pdf

from textboxes.

MhLiao commented on June 22, 2024

@ahmedmazari-dhatim See Equation 14 in the paper above, which describes how CTC yields the probability. The variable logProb in "crnn/src/cpp/ctc.cpp" holds the log-probability, so applying an "exp" operation to it gives the score.
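
To make the "exp of the log-likelihood" idea concrete, here is a minimal pure-Python sketch of the CTC forward recursion. It is not the code from crnn/src/cpp/ctc.cpp; the function names, the argument layout, and the choice of index 0 for the blank are assumptions made for illustration only.

```python
import math

NEG_INF = float("-inf")


def log_sum_exp(*values):
    """Stable log(sum(exp(v))) that tolerates -inf inputs."""
    finite = [v for v in values if v != NEG_INF]
    if not finite:
        return NEG_INF
    m = max(finite)
    return m + math.log(sum(math.exp(v - m) for v in finite))


def ctc_log_prob(log_probs, labels, blank=0):
    """Log-likelihood of `labels` given per-frame class log-probabilities.

    log_probs: list of T frames, each a list of log-probabilities per class.
    labels:    list of class indices without blanks (hypothetical layout).
    """
    # Extended label sequence: blank, l1, blank, l2, ..., blank
    ext = [blank]
    for c in labels:
        ext += [c, blank]
    S = len(ext)

    # Forward variables at t = 0: a path may start with a blank or the first label.
    alpha = [NEG_INF] * S
    alpha[0] = log_probs[0][ext[0]]
    if S > 1:
        alpha[1] = log_probs[0][ext[1]]

    for t in range(1, len(log_probs)):
        new = [NEG_INF] * S
        for s in range(S):
            candidates = [alpha[s]]
            if s >= 1:
                candidates.append(alpha[s - 1])
            # Skipping the blank in between is only allowed between different labels.
            if s >= 2 and ext[s] != blank and ext[s] != ext[s - 2]:
                candidates.append(alpha[s - 2])
            total = log_sum_exp(*candidates)
            if total != NEG_INF:
                new[s] = total + log_probs[t][ext[s]]
        alpha = new

    # Valid paths end on the final blank or on the final label.
    if S > 1:
        return log_sum_exp(alpha[S - 1], alpha[S - 2])
    return alpha[S - 1]


# Two frames, two classes (blank and one character), uniform predictions.
frames = [[math.log(0.5), math.log(0.5)]] * 2
score = math.exp(ctc_log_prob(frames, [1]))  # approximately 0.75
print(score)
```

In this toy case three of the four possible paths collapse to the single-character label, so the score comes out at about 0.75. The same exp step applies to whatever log-likelihood a real CTC implementation reports.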

ahmedmazari-dhatim commented on June 22, 2024

Hello @MhLiao,

Thank you very much for your answer. However, I'm using the PyTorch version, which is why I'm asking. The PyTorch version of CRNN gives me no access to crnn/src/cpp/ctc.cpp.

In crnn_main.py I have the following:
criterion = CTCLoss()
where CTCLoss() is defined as:

import torch
import warpctc_pytorch as warp_ctc
from torch.autograd import Function
from torch.nn import Module
from torch.nn.modules.loss import _assert_no_grad
from torch.utils.ffi import _wrap_function
from ._warp_ctc import lib as _lib, ffi as _ffi

__all__ = []


def _import_symbols(locals):
    for symbol in dir(_lib):
        fn = getattr(_lib, symbol)
        locals[symbol] = _wrap_function(fn, _ffi)
        __all__.append(symbol)


_import_symbols(locals())


class _CTC(Function):
    def forward(self, acts, labels, act_lens, label_lens):
        is_cuda = acts.is_cuda
        acts = acts.contiguous()
        loss_func = warp_ctc.gpu_ctc if is_cuda else warp_ctc.cpu_ctc
        grads = torch.zeros(acts.size()).type_as(acts)
        minibatch_size = acts.size(1)
        costs = torch.zeros(minibatch_size)
        loss_func(acts,
                  grads,
                  labels,
                  label_lens,
                  act_lens,
                  minibatch_size,
                  costs)
        self.grads = grads
        self.costs = torch.FloatTensor([costs.sum()])
        return self.costs

    def backward(self, grad_output):
        return self.grads, None, None, None


class CTCLoss(Module):
    def __init__(self):
        super(CTCLoss, self).__init__()

    def forward(self, acts, labels, act_lens, label_lens):
        """
        acts: Tensor of (seqLength x batch x outputDim) containing output from network
        labels: 1 dimensional Tensor containing all the targets of the batch in one sequence
        act_lens: Tensor of size (batch) containing size of each output sequence from the network
        label_lens: Tensor of (batch) containing label length of each example
        """
        _assert_no_grad(labels)
        _assert_no_grad(act_lens)
        _assert_no_grad(label_lens)
        return _CTC()(acts, labels, act_lens, label_lens)



I'm wondering whether there is a way to get the probabilities from the PyTorch version, as you suggested. Any idea how to do that, @MhLiao?

In the C++ source, the corresponding line for getting the probabilities is line 80:

// compute log-likelihood
T logProb = fvars.at({inputLength-1, nSegment-1});

Thank you again

ahmedmazari-dhatim commented on June 22, 2024

Hi @MhLiao,

Thank you very much for your answer. However, I can't find where CRNN outputs the probabilities. Could you tell me where I can print them?

Thank you @MhLiao

MhLiao commented on June 22, 2024

@ahmedmazari-dhatim Sorry, I have not read the PyTorch code. My guess is that the PyTorch code also uses warp-ctc, which is written in C++.
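
If that guess holds, the per-sample `costs` tensor filled inside `_CTC.forward` in the wrapper quoted in this thread may already carry the needed quantity. The sketch below assumes each cost entry is the negative log-likelihood of the supplied label sequence, which is how warp-ctc is generally understood to report costs; this has not been verified against the wrapper. The helper name is hypothetical.

```python
import math


def costs_to_probabilities(costs):
    """Convert per-sample CTC costs to sequence probabilities.

    Assumption: each cost is -log p(labels | input), so p = exp(-cost).
    `costs` can be any iterable of numbers or scalar tensors that float() accepts.
    """
    return [math.exp(-float(c)) for c in costs]


# A cost of log(4) corresponds to a probability of about 0.25.
print(costs_to_probabilities([math.log(4.0), 0.0]))
```

Under this assumption, the value to inspect is `costs` before it is summed into `self.costs`. To score a prediction at inference time, the decoded text would be passed as the labels.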

ahmedmazari-dhatim commented on June 22, 2024

Hi @MhLiao,
Yes, but I can't find how to access cpp/ctc.cpp from the PyTorch version.

The PyTorch version only has this class:
CTCLoss()

