How to type hint `flax.linen.Module.apply`'s output correctly?

As of writing, this code does not pass the Pyright type checker:
```python
import jax
import jax.numpy as jnp
import jax.typing as jt
import flax.linen as nn


class MLP(nn.Module):
    @nn.compact
    def __call__(self, inputs: jt.ArrayLike):
        inputs = jnp.array(inputs)
        return nn.Dense(4)(inputs)


if __name__ == "__main__":
    inputs = jnp.ones((2, 4))
    mlp = MLP()
    prng_key = jax.random.key(7)
    outputs: jax.Array = mlp.apply(mlp.init(prng_key, inputs), inputs)
    print(f"type of outputs is {type(outputs)}")
```
Its output is `type of outputs is <class 'jaxlib.xla_extension.ArrayImpl'>`.
On the line `outputs: jax.Array = mlp.apply(mlp.init(prng_key, inputs), inputs)`, Pylance shows this error:

```
Type "Any | tuple[Any, FrozenVariableDict | dict[str, Any]]" is not assignable to declared type "Array"
  Type "Any | tuple[Any, FrozenVariableDict | dict[str, Any]]" is not assignable to type "Array"
    "tuple[Any, FrozenVariableDict | dict[str, Any]]" is not assignable to "Array"
Pylance(reportAssignmentType)
```
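For context, the declared type in this message is the union that `Module.apply` is annotated with: the module output alone, or an `(output, mutated_variables)` tuple when collections are made mutable. The sketch below reproduces the same situation without Flax. `fake_apply` is a hypothetical stand-in written only for illustration, with a simplified signature and a plain `dict` in place of `FrozenVariableDict`; it is not Flax code.

```python
from typing import Any, Dict, Tuple, Union, cast


# Hypothetical stand-in for flax.linen.Module.apply: same shape of return
# annotation, simplified to a single boolean `mutable` flag.
def fake_apply(x: float, mutable: bool = False) -> Union[Any, Tuple[Any, Dict[str, Any]]]:
    output = x * 2.0
    if mutable:
        # With mutable collections, the output is paired with the updated state.
        return output, {"batch_stats": {}}
    # Default path: only the output is returned, not a tuple.
    return output


# A checker only sees the declared union, so `y: float = fake_apply(1.0)` is
# rejected for the same reason as the Array assignment above.
y = cast(float, fake_apply(1.0))  # cast() informs the checker only; no runtime effect
print(y)  # 2.0

pair = fake_apply(1.0, mutable=True)
print(type(pair).__name__)  # tuple
```

Because the annotation does not depend on the value of `mutable`, the checker cannot narrow the union from the call site alone.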
Any other operation that treats `outputs` as a JAX array succeeds at runtime but also fails to type check.
Yet the output of that function is an array. This code type checks correctly:

```python
import jax
import jax.numpy as jnp

x: jax.Array = jnp.zeros((2, 2))
print(f"x's type is {type(x)}")
```

and its output is `x's type is <class 'jaxlib.xla_extension.ArrayImpl'>`, the same type.
How should I use this module to get the output array in a way that pyright can type check correctly?
Answer
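The error you're seeing from Pyright comes from the return annotation of `Module.apply`. By default, `apply` returns just the module output, but when you pass `mutable=...` (for example `mutable=["batch_stats"]`) it returns a tuple `(output, mutated_variables)`. The method is therefore annotated with the union `Any | tuple[Any, FrozenVariableDict | dict[str, Any]]`, and since Pyright cannot tell from the call which branch applies, the union as a whole cannot be assigned to a variable declared as `jax.Array`.

Your call does not set `mutable`, so the returned value really is the output array. Don't unpack it: `outputs, _ = ...` would iterate over the `(2, 4)` array along its first axis and silently give you only the first row. Instead, tell the type checker which branch you are in with `typing.cast`:

```python
import typing

import jax
import jax.numpy as jnp
import jax.typing as jt
import flax.linen as nn


class MLP(nn.Module):
    @nn.compact
    def __call__(self, inputs: jt.ArrayLike):
        inputs = jnp.array(inputs)
        return nn.Dense(4)(inputs)


if __name__ == "__main__":
    inputs = jnp.ones((2, 4))
    mlp = MLP()
    prng_key = jax.random.key(7)
    variables = mlp.init(prng_key, inputs)
    # `mutable` is not set, so apply() returns only the output array;
    # cast() narrows the declared union for the type checker (no runtime effect)
    outputs = typing.cast(jax.Array, mlp.apply(variables, inputs))
    print(f"type of outputs is {type(outputs)}")
```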