How to type hint `flax.linen.Module.apply`'s output correctly?

Ethan Jackson

As of writing, this code does not pass the Pyright type checker:

```python
import jax
import jax.numpy as jnp
import jax.typing as jt
import flax.linen as nn


class MLP(nn.Module):
    @nn.compact
    def __call__(self, inputs: jt.ArrayLike):
        inputs = jnp.array(inputs)
        return nn.Dense(4)(inputs)


if __name__ == "__main__":
    inputs = jnp.ones((2, 4))
    mlp = MLP()
    prng_key = jax.random.key(7)
    outputs: jax.Array = mlp.apply(mlp.init(prng_key, inputs), inputs)
    print(f"type of outputs is {type(outputs)}")
```

Its output is `type of outputs is <class 'jaxlib.xla_extension.ArrayImpl'>`.

On the line `outputs: jax.Array = mlp.apply(mlp.init(prng_key, inputs), inputs)`, Pylance shows this error:

```text
Type "Any | tuple[Any, FrozenVariableDict | dict[str, Any]]" is not assignable to declared type "Array"
  Type "Any | tuple[Any, FrozenVariableDict | dict[str, Any]]" is not assignable to type "Array"
    "tuple[Any, FrozenVariableDict | dict[str, Any]]" is not assignable to "Array"
Pylance(reportAssignmentType)
```

Any other operation that treats `outputs` as a JAX array likewise succeeds at runtime but fails to type check.

Yet the output of that function is an array. This code type checks correctly:

```python
import jax
import jax.numpy as jnp

x: jax.Array = jnp.zeros((2, 2))
print(f"x's type is {type(x)}")
```

and its output is `x's type is <class 'jaxlib.xla_extension.ArrayImpl'>`, the same type.

How should I use this module to get the output array in a way that Pyright can type check?

Answer

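The error comes from how `apply` is annotated, not from what it returns in your call. Its return type is declared as `Any | tuple[Any, FrozenVariableDict | dict[str, Any]]` because `apply` returns the bare output when `mutable` is left at its default (`False`), and an `(output, updated_variables)` tuple when some collections are marked mutable. Pyright cannot tell from the call which case applies, so it refuses to assign the union to `jax.Array`.

Since your call does not set `mutable`, the result is the array itself, so do not unpack it: `outputs, _ = mlp.apply(...)` would iterate over the first axis of the `(2, 4)` result and silently keep only its first row (and raise a `ValueError` for any batch size other than 2). Instead, tell the type checker what the value is with `typing.cast`:

```python
from typing import cast

import jax
import jax.numpy as jnp
import jax.typing as jt
import flax.linen as nn


class MLP(nn.Module):
    @nn.compact
    def __call__(self, inputs: jt.ArrayLike) -> jax.Array:
        inputs = jnp.array(inputs)
        return nn.Dense(4)(inputs)


if __name__ == "__main__":
    inputs = jnp.ones((2, 4))
    mlp = MLP()
    prng_key = jax.random.key(7)
    variables = mlp.init(prng_key, inputs)
    # mutable defaults to False, so apply returns only the output array;
    # cast narrows the declared union for Pyright and has no runtime effect
    outputs = cast(jax.Array, mlp.apply(variables, inputs))
    print(f"type of outputs is {type(outputs)}")
```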
