Introdução ao JAX
“Numpy“ em CPU, GPU e TPU.
Este artigo possui a intenção de apresentar o JAX, seus casos de uso e algumas transformações.

O que é e para que serve?
JAX (Just After eXecution) é um pacote Python criado pela Google Research que tem como objetivo acelerar a performance de sua pesquisa em Machine Learning (ML). Para isso, ele utiliza um conjunto de funções que permitem tirar o máximo proveito do compilador XLA (mais detalhes no próximo paragrafo), resultando em uma alta performance na execução dos seus algoritmos de ML (esses pontos ficarão mais evidentes no decorrer do texto). Tudo isso é oferecido através de uma API NumPy-like, ou seja, você codificará com JAX semelhantemente a como você codifica usando Numpy.

O Accelerated Linear Algebra (XLA) é um compilador de domínio específico para álgebra linear que permite transitar entre CPU, GPU e TPU sem adaptações no nosso código Python. Para utilizar o compilador XLA com o máximo de eficiência, JAX rastreia as funções e as traduz para uma linguagem que permite interpretar e otimizar no compilador XLA, essa linguagem é chamada de Jaxpr e você pode entender melhor clicando aqui.

Algo interessante a comentar é que JAX é um framework recente (seu lançamento foi em dezembro de 2018), mas que já ganhou alguma notoriedade, aparecendo em terceiro lugar nos Trends do Papers With Code.

Vale ressaltar que JAX não é tratado como um produto, mas sim como uma ferramenta de pesquisa.
Alguns links que podem ser úteis:
Transformações
JAX possuí muitas funções de transformação, aqui irei apenas resumir as principais. Saiba que todas as transformações no JAX foram pensadas para funcionarem de forma composta, isto é, você conseguirá usar muitas transformações em conjunto sem problemas.
grad
Essa transformação recebe uma função e retorna uma nova função que possibilita calcular o gradiente da função original. É bastante útil em otimização de funções complexas de forma iterativamente.

jit
Just-in-time (JIT), trata-se de compilar o código em tempo de execução. No caso do JAX, o JIT permite transformar uma função escrita em Python em um artefato compilado (especificamente um Jaxpr) para o compilador XLA. Isso resulta em uma otimização da função para sua execução no XLA.
Vale ressaltar que nem toda função em Python conseguirá ser compilada corretamente, mas no contexto de ML (onde as funções em sua maioria são matemáticas) é bem provável que consiga.

vmap
Permite criar uma versão vetorizada de uma função de forma automática. É um ótimo adicional para funções compiladas pelo JIT, deixando sua função ainda mais rápido.

pmap
Serve para executarmos nossas funções de forma paralela automaticamente.

Vale observar que o retorno do pmap foi um tipo diferente (ShardedDeviceArray), justamente porque ele divide os elementos da matriz em todos os dispositivos usados no paralelismo.
Quando usar?
No contexto de pesquisa em ML, principalmente em Deep Learning (DL), é interessante executarmos nossos modelos em hardwares que possuem mais poder computacional. O Numpy por si só não executa em GPU e TPU, o que faz com que os pesquisadores precisem utilizar outras ferramentas, como CuPy e Numba. O JAX nos permite executar o mesmo código em diferentes dispositivos sem muito trabalho ou adaptação do nosso lado.
Ele também é recomendado para tarefas que exigem um processamento grande de dados, possuindo mecanismos que facilitar a paralelização (pmap). Além disso, possui outras formas de otimização, como o JIT (just-in-time) que compila o código em tempo de execução e a vetorização de funções (vmap).
Devido a sua liberdade e eficiência, creio que seja uma ferramenta bastante adequada para times mais maduros e onde o foco seja as otimizações de um modela já bem estabelecido.
Quando não usar?
Ao meu ver, a principal desvantagem do JAX é a tarefa de desenvolver alguns modelos de ML/DL do zero, sem uso de modelos já prontos e oferecidos pelo pacote. Para empresas novas ou até times novos, onde o foco inicial é a entrega de um produto usável para o cliente, creio que o uso de framework como PyTorch ou Tensorflow sejam melhores opções.
Uma desvantagem é que os arrays no JAX são imutáveis, diferentemente dos arrays em Numpy. Então se você costuma fazer atribuições de valor direto no seu array isso pode te atrapalhar um pouco.
Referências
A referência utilizada para a escrita deste artigo foi a documentação oficial do JAX. Ela é bastante completa e contêm muitos exemplos (inclusive os que aparecem nesse artigo).