Speed Optimization Basics: Numba

When to use Numba

  • Numba works well when the code relies heavily on (1) numpy, (2) loops, and/or (3) cuda.
  • Hence, we want to maximize the use of Numba wherever our code involves loops or numpy operations.

Numba CPU: nopython

  • For a basic Numba application, we can decorate a Python function so that it runs without the Python interpreter
  • Essentially, Numba compiles the function into machine code once for a given set of argument types, then reuses the cached machine code on subsequent calls (the one-off compile cost is shown in the sketch after the timing below)

With Numba: nopython

from numba import jit, prange
import numpy as np

# NumPy array of 10k elements
input_ndarray = np.random.rand(10000)

# This is the only extra line of code you need to add
# which is a decorator
@jit(nopython=True)
def go_fast(a):
    trace = 0
    for i in range(a.shape[0]):
        trace += np.tanh(a[i])
    return a + trace

%timeit go_fast(input_ndarray)
161 µs ± 2.62 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
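
The microsecond timings above are possible because %timeit calls the function many times, and only the very first call triggers compilation. A minimal sketch of the one-off compile cost versus the cached calls (scaled_sum is just an illustrative helper, not part of the example above):

from numba import njit
import numpy as np
import time

@njit
def scaled_sum(a):
    total = 0.0
    for i in range(a.shape[0]):
        total += a[i] * 2.0
    return total

x = np.random.rand(10000)

# First call: Numba compiles the function for float64 arrays (slow, one-off cost)
start = time.perf_counter()
scaled_sum(x)
print("first call :", time.perf_counter() - start)

# Second call: the cached machine code is reused (fast)
start = time.perf_counter()
scaled_sum(x)
print("second call:", time.perf_counter() - start)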

Without numba

  • This is much slower: time is measured in milliseconds rather than the microseconds we saw with @jit(nopython=True), or its shorthand @njit (a short @njit sketch follows the timing below)
# Without numba: notice how this is really slow
def go_normal(a):
    trace = 0
    for i in range(a.shape[0]):
        trace += np.tanh(a[i])
    return a + trace

%timeit go_normal(input_ndarray)
10.5 ms ± 163 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
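
For completeness, @njit is simply shorthand for @jit(nopython=True); the same function can be written as below (go_fast_njit is just an illustrative name):

from numba import njit
import numpy as np

# Identical behaviour to @jit(nopython=True) above
@njit
def go_fast_njit(a):
    trace = 0
    for i in range(a.shape[0]):
        trace += np.tanh(a[i])
    return a + trace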

Numba CPU: parallel

  • Here, instead of the normal range() function we would use in our for loops, we need to use prange(), which executes the loop iterations in parallel on separate threads
  • As you can see, it's slightly faster than @jit(nopython=True) alone (a sketch of the other common parallel pattern, elementwise writes, follows the timing below)
@jit(nopython=True, parallel=True)
def go_even_faster(a):
    trace = 0
    for i in prange(a.shape[0]):
        trace += np.tanh(a[i])
    return a + trace

%timeit go_even_faster(input_ndarray)
148 µs ± 71.3 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
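
The scalar accumulation above (trace += ...) is a reduction, which Numba recognises and parallelises safely. The other common pattern is purely elementwise work, where each iteration writes to its own output index; a minimal sketch (tanh_all is just an illustrative name):

from numba import njit, prange
import numpy as np

@njit(parallel=True)
def tanh_all(a):
    out = np.empty_like(a)
    # Each iteration writes to a distinct index, so the threads never race
    for i in prange(a.shape[0]):
        out[i] = np.tanh(a[i])
    return out

tanh_all(input_ndarray)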

Numba CPU: fastmath

  • What if we relax the requirement of strictly adhering to IEEE 754 floating-point semantics?
  • We can get faster performance, although the gain depends on the workload
  • I would say this gives the smallest additional speed-up unless you really dig into areas where fastmath=True thrives (a sketch of one such case follows the timing below)
@jit(nopython=True, parallel=True, fastmath=True)
def go_super_fast(a):
    trace = 0
    for i in prange(a.shape[0]):
        trace += np.tanh(a[i])
    return a + trace

%timeit go_super_fast(input_ndarray)
113 µs ± 39.7 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
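
As a rough sketch of where fastmath=True tends to matter: relaxing IEEE 754 lets LLVM reassociate floating-point additions, so straight reductions like the sum below can be SIMD-vectorised at the cost of slightly different rounding (fast_sum is just an illustrative name):

from numba import njit
import numpy as np

@njit(fastmath=True)
def fast_sum(a):
    # With fastmath, LLVM may reorder these additions and vectorise the loop;
    # the result can differ slightly from a strict left-to-right IEEE 754 sum
    total = 0.0
    for i in range(a.shape[0]):
        total += a[i]
    return total

fast_sum(input_ndarray)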

Summary

  • When to use Numba
    • (1) numpy array or torch tensors,
    • (2) loops, and/or
    • (3) cuda
  • Numba CPU: nopython
  • Numba CPU: parallel
  • Numba CPU: fastmath
