
Comments (12)

polvalente commented on June 19, 2024

You have to discount the initial run because it also includes XLA client initialization as well as JIT compilation time. If your timings are proportional to mine, you should see a considerable reduction in execution time.

That being said, it seems that the majority of the time is spent calculating the FFT itself. I wonder if we'll have to do something in Nx similar to what we did with SVD. I'll look into it.


polvalente commented on June 19, 2024

Ah, also, make sure to set Nx.default_backend(EXLA.Backend) too :)
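
For reference, a minimal sketch combining both suggestions above (setting the default backend and discounting the first, JIT-compiling run); run_stft here is a hypothetical zero-arity wrapper around whatever computation is being timed:

Nx.default_backend(EXLA.Backend)

run_stft.()                               # first call pays XLA client init + JIT compilation
{time_us, _result} = :timer.tc(run_stft)  # subsequent calls reflect the steady-state cost
IO.puts("steady-state run: #{time_us / 1000} ms")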


polvalente commented on June 19, 2024

OK, I found the main problem: EXLA doesn't deal with radix-N FFT. If you use window_size: 512 or some other power of 2, the FFT is a lot faster. Also, padding: :reflect will slow down your calculations. I just noticed that your PyTorch code doesn't have padding='reflect' set anywhere, so maybe that also contributes.

Anyway, I got a speedup from 1s to 0.1s (100ms) by using padding: :valid (which is the default), and then to ~30ms by changing the window size from 400 to 512 (same hop), or ~20ms with window_size: 256. All times were taken after a few warmup runs to exclude memory allocation and JIT compilation effects from the measurement.


mortont commented on June 19, 2024

Yes, sorry I wasn't clear: that 1.5s runtime was after the JIT had been "warmed up"; the first run was 1.7s. I also left out the config where I set EXLA as the default backend, but I confirmed it is using EXLA.

config/config.exs:

import Config
config :nx, :default_backend, EXLA.Backend

I'm seeing similar numbers to yours after the changes, thanks! Although, FWIW, padding='reflect' is PyTorch's default, and explicitly setting it doesn't seem to alter the speed.

I wonder why it's still 10x slower? Typically EXLA is faster than PyTorch in my experience.


polvalente commented on June 19, 2024

The great majority of the slowdown is due to the window size not being a power of 2. I'm gonna experiment with the fix I mentioned in the Nx issue; let's see where that goes.

To go into a bit more detail, the most straightforward/intuitive implementation of the Fast Fourier Transform algorithm is only Fast™ for vectors whose length is a power of 2. After all factors of 2 are exhausted (400 = 2⁴ × 25), the remainder is computed as a plain DFT instead (of length 25 in this case), which is what makes the difference so significant.

But there are radix-N algorithms that can handle other prime factors in the recursion step, and are thus faster for such lengths.
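
As a rough illustration of the size effect, here is a sketch (shapes mirror the ones discussed above; absolute numbers will vary by machine):

Nx.default_backend(EXLA.Backend)

t400 = Nx.iota({3000, 400}, type: :f32)
t512 = Nx.iota({3000, 512}, type: :f32)

# warm up both shapes so JIT compilation doesn't skew the measurement
Nx.fft(t400)
Nx.fft(t512)

{us_400, _} = :timer.tc(fn -> Nx.fft(t400) end)  # 400 = 2^4 * 25: falls back to a length-25 DFT
{us_512, _} = :timer.tc(fn -> Nx.fft(t512) end)  # 512 = 2^9: pure radix-2 all the way down
IO.puts("length 400: #{us_400 / 1000} ms, length 512: #{us_512 / 1000} ms")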


polvalente commented on June 19, 2024

And even after getting them to comparable implementations, it might be the case that either XLA doesn't optimize complex numbers all that well, or the core FFT implementation isn't as fast as it could be (which I doubt).

Given the tensor size, we might also be seeing memory allocation overhead on the EXLA side of things.

Lots of factors to explore.


polvalente commented on June 19, 2024

By the way, have you tried running the FFT with a random (or at least nonzero) tensor in both libs?

edit: I just ran it, so it isn't some zero-input shortcut that makes it 5x faster.
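
For reference, such a nonzero-input check can use Nx.Random instead of Nx.iota (a sketch; the key seed and shape are arbitrary):

key = Nx.Random.key(1701)
{t, _new_key} = Nx.Random.uniform(key, shape: {3000, 400}, type: :f32)
Nx.fft(t)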


polvalente commented on June 19, 2024

As mentioned in the Nx issue, this is an upstream issue with no immediate fix on our side.
tensorflow/tensorflow#6541


polvalente commented on June 19, 2024

@mortont I think we can close this issue for now. Maybe we can open a separate one for padding: :reflect itself.

Here's a benchmark I ran in Benchee using CUDA for both EXLA and Torchx:

Code

Mix.install(
  [
    :benchee,
    {:nx, github: "elixir-nx/nx", sparse: "nx", override: true},
    {:exla, github: "elixir-nx/nx", sparse: "exla"},
    {:torchx, github: "elixir-nx/nx", sparse: "torchx"}
  ],
  system_env: %{"XLA_TARGET" => "cuda114", "LIBTORCH_TARGET" => "cu117"}
)
Application.put_env(:exla, :clients, [host: [platform: :host], cuda: [platform: :cuda, preallocate: true, memory_fraction: 0.5]])

defmodule TorchxFFT do
  import Nx.Defn
    
  defn fft(t) do
    while {t = Nx.as_type(t, :c64)}, i <- 1..10, unroll: true do
      {Nx.fft(t + i)}
    end
  end
end

Nx.Defn.default_options(compiler: EXLA)
defmodule EXLAFFT do
  import Nx.Defn
    
  defn fft(t) do
    while {t = Nx.as_type(t, :c64)}, i <- 1..10, unroll: true do
      {Nx.fft(t + i)}
    end
  end
end

Benchee.run(%{
  "EXLA while" => fn input -> EXLAFFT.fft(input.exla) end,
  "EXLA fft" => &Enum.each(1..10, fn _ -> Nx.fft(&1.exla) end),
  "Torchx fft" => &Enum.each(1..10, fn _ -> Nx.fft(&1.torchx) end),
  "Torchx while" => fn input -> TorchxFFT.fft(input.torchx) end
}, inputs: %{
  "1x400" => %{exla: Nx.iota({1, 400}, backend: {EXLA.Backend, client: :cuda, preallocate: false}), torchx: Nx.iota({1, 400}, backend: {Torchx.Backend, device: :cuda})}, 
  "1x512" => %{exla: Nx.iota({1, 512}, backend: {EXLA.Backend, client: :cuda, preallocate: false}), torchx: Nx.iota({1, 512}, backend: {Torchx.Backend, device: :cuda})}, 
  "3kx400" => %{exla: Nx.iota({3000, 400}, backend: {EXLA.Backend, client: :cuda, preallocate: false}), torchx: Nx.iota({3000, 400}, backend: {Torchx.Backend, device: :cuda})}, 
  "3kx512" => %{exla: Nx.iota({3000, 512}, backend: {EXLA.Backend, client: :cuda, preallocate: false}), torchx: Nx.iota({3000, 512}, backend: {Torchx.Backend, device: :cuda})},
  "4096x512" => %{exla: Nx.iota({4096, 512}, backend: {EXLA.Backend, client: :cuda, preallocate: false}), torchx: Nx.iota({4096, 512}, backend: {Torchx.Backend, device: :cuda})},
}, warmup: 5, time: 10); nil

The idea was to check whether while had a positive or negative impact on the "let's run this computation 10 times on the GPU" scenario.

Results

With input 1x400

Name ips average deviation median 99th %
Torchx fft 1599.15 0.63 ms ±13.99% 0.61 ms 0.97 ms
Torchx while 893.97 1.12 ms ±10.95% 1.10 ms 1.55 ms
EXLA fft 855.54 1.17 ms ±12.23% 1.14 ms 1.67 ms
EXLA while 268.87 3.72 ms ±10.61% 3.63 ms 4.96 ms

Comparison:
Torchx fft 1599.15
Torchx while 893.97 - 1.79x slower +0.49 ms
EXLA fft 855.54 - 1.87x slower +0.54 ms
EXLA while 268.87 - 5.95x slower +3.09 ms

With input 1x512

Name ips average deviation median 99th %
Torchx fft 1472.02 0.68 ms ±15.55% 0.65 ms 0.97 ms
Torchx while 906.34 1.10 ms ±14.54% 1.07 ms 1.63 ms
EXLA fft 861.77 1.16 ms ±11.67% 1.13 ms 1.62 ms
EXLA while 277.07 3.61 ms ±8.88% 3.54 ms 4.61 ms

Comparison:
Torchx fft 1472.02
Torchx while 906.34 - 1.62x slower +0.42 ms
EXLA fft 861.77 - 1.71x slower +0.48 ms
EXLA while 277.07 - 5.31x slower +2.93 ms

With input 3kx400

Name ips average deviation median 99th %
Torchx while 231.23 4.32 ms ±2.77% 4.32 ms 4.57 ms
Torchx fft 224.89 4.45 ms ±0.97% 4.45 ms 4.49 ms
EXLA fft 143.81 6.95 ms ±6.81% 6.98 ms 8.00 ms
EXLA while 94.47 10.58 ms ±5.62% 10.53 ms 12.01 ms

Comparison:
Torchx while 231.23
Torchx fft 224.89 - 1.03x slower +0.122 ms
EXLA fft 143.81 - 1.61x slower +2.63 ms
EXLA while 94.47 - 2.45x slower +6.26 ms

With input 3kx512

Name ips average deviation median 99th %
Torchx while 231.56 4.32 ms ±1.61% 4.30 ms 4.55 ms
Torchx fft 181.71 5.50 ms ±0.79% 5.50 ms 5.57 ms
EXLA fft 137.46 7.27 ms ±7.17% 7.28 ms 8.46 ms
EXLA while 94.11 10.63 ms ±4.05% 10.62 ms 11.61 ms

Comparison:
Torchx while 231.56
Torchx fft 181.71 - 1.27x slower +1.18 ms
EXLA fft 137.46 - 1.68x slower +2.96 ms
EXLA while 94.11 - 2.46x slower +6.31 ms

With input 4096x512

Name ips average deviation median 99th %
Torchx while 174.26 5.74 ms ±1.66% 5.74 ms 5.97 ms
Torchx fft 134.62 7.43 ms ±0.48% 7.43 ms 7.48 ms
EXLA fft 112.89 8.86 ms ±5.64% 8.90 ms 10.00 ms
EXLA while 82.10 12.18 ms ±5.32% 12.08 ms 13.90 ms

Comparison:
Torchx while 174.26
Torchx fft 134.62 - 1.29x slower +1.69 ms
EXLA fft 112.89 - 1.54x slower +3.12 ms
EXLA while 82.10 - 2.12x slower +6.44 ms

Conclusion

I don't quite understand how Torchx with while ends up being more performant than with for, since the Defn evaluator is also "pure" Elixir (cc @josevalim). However, the results were consistent for every input.

It's worth noting that while the relative differences are significant, we're talking about differences of less than 5ms between EXLA fft and Torchx fft, which are directly equivalent, and of less than 10ms between "Torchx while" and "EXLA while", which are probably the more natural use cases.

The point I want to make is not that 10ms is an insignificant difference, but that it is miles away from the many tens of ms we had in the best-case scenario with CPU execution. It's important to ensure that cuFFT (https://docs.nvidia.com/cuda/cufft/index.html) is installed.
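
A rough sanity check (a sketch, not a direct cuFFT probe) is to confirm the tensors actually live on EXLA's :cuda client, which only exists when EXLA was built with a CUDA XLA_TARGET:

t = Nx.iota({3000, 512}, type: :f32, backend: {EXLA.Backend, client: :cuda})
Nx.fft(t) |> IO.inspect()  # the inspect output includes the backend/device the result lives on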

Also note that I benchmarked only the FFT because that's where we have the least control.
It might be worth opening a separate issue if, even after ensuring cuFFT is used, you still see less than desirable performance with STFT. padding: :reflect was a problematic option to implement, so bear that in mind when using it.


josevalim commented on June 19, 2024

Only the function dispatch is pure Elixir; it still calls the backend functions, so I assume the FFT/TFFT code in XLA is not the fastest? Does Jax do anything special to speed up the FFT code?


polvalente commented on June 19, 2024

@josevalim I was referring specifically to the fact that the Torchx code using while was somehow faster than the Elixir code using for. Probably an artifact of the benchmark, though.


mortont commented on June 19, 2024

Great, thanks for the exhaustive benchmark!

From the upstream thread you mentioned, it looks like Jax has worked around the slow FFT implementation by using PocketFFT rather than XLA's (Eigen-based) FFT, specifically for FFTs on the CPU (google/jax#2952). Is that something EXLA could do?

With that said, changing the FFT size to a power of 2 and changing the STFT window padding to :same has fixed the performance issue for my application, so I'll close this issue.

