JAX class member or argument of a function

Ethan Jackson

In JAX, I changed a constant variable from a class member to a function argument and observed different behavior. Here are the details:

  1. I have a JAX function that takes a class object Obj as an input, so a call looks like func(Obj, ...), where ... stands for the other inputs.
  2. I had a constant variable as a class member, Obj.var. I registered the class as a pytree, as suggested in the JAX documentation, and made Obj.var static when registering it because I don't need to change it.
  3. Recently, I removed var from the class and made it a function argument, as in func(Obj, var, ...). I didn't make var static because I may want to batch over it.
  4. After this change, the result of the function is different.
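A minimal sketch of the two setups described above, assuming a hypothetical class with one array field `w` plus the constant `var` (the names `Obj`, `w`, `func_member`, and `func_arg` are illustrative, not taken from the original code). A static pytree field is returned as the auxiliary data of the flatten function, which is why it must be hashable.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


class Obj:
    """Hypothetical stand-in for the class in the question."""

    def __init__(self, w, var):
        self.w = w      # array field: a traced pytree leaf
        self.var = var  # constant: stored as static auxiliary data


def _flatten(obj):
    # Returns (children, aux_data); aux_data is static and must be hashable.
    return (obj.w,), obj.var


def _unflatten(var, children):
    (w,) = children
    return Obj(w, var)


tree_util.register_pytree_node(Obj, _flatten, _unflatten)


@jax.jit
def func_member(obj, x):
    # Setup 2: var is read from the static field of the pytree.
    return obj.w * x + obj.var


@jax.jit
def func_arg(obj, var, x):
    # Setup 3: var is an ordinary, traced argument (obj.var is ignored here).
    return obj.w * x + var


obj = Obj(jnp.float32(2.0), 0.5)
x = jnp.float32(3.0)
y_member = func_member(obj, x)
y_arg = func_arg(obj, 0.5, x)
```

For a simple expression like this both versions return 6.5; the differences lie in how each version is compiled, differentiated, and batched.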

Is this expected? What is good practice for handling a constant value like this?

Answer

Yes, a difference in behavior can be expected in JAX. When var was a static field of the registered pytree, JAX treated it as a compile-time constant: its value is baked into the JIT-compiled code (a different value triggers a new trace and compilation), and because it is not a pytree leaf, it is not part of the differentiable computation.
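The static case can be illustrated with a small sketch, again using a hypothetical `Obj` class with a static `var` field. A Python list records each trace: the `append` is a side effect that runs only while JAX traces the function, so it shows when a new value of `var` forces a recompilation and that `var` is a plain Python `float` during tracing. The last lines show that `jax.grad` computes no derivative for the static field.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


class Obj:
    def __init__(self, w, var):
        self.w = w
        self.var = var


# w is a pytree leaf; var is static auxiliary data.
tree_util.register_pytree_node(
    Obj,
    lambda o: ((o.w,), o.var),
    lambda var, children: Obj(children[0], var),
)

trace_log = []  # appended to only while JAX is tracing `func`


@jax.jit
def func(obj, x):
    trace_log.append(type(obj.var))  # Python side effect: runs at trace time
    return obj.w * x + obj.var


x = jnp.float32(3.0)
func(Obj(jnp.float32(2.0), 0.5), x)  # first call: traces and compiles
func(Obj(jnp.float32(4.0), 0.5), x)  # new w, same var: reuses the cache
func(Obj(jnp.float32(2.0), 1.5), x)  # new var: new static data, retraces
n_traces = len(trace_log)            # 2

# The gradient pytree carries a real derivative only for the leaf w;
# g.var is just the static value copied over from the input's structure.
g = jax.grad(lambda o: func(o, x))(Obj(jnp.float32(2.0), 0.5))
```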

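When var became a non-static function argument, JAX instead treats it as a dynamic value that can change between calls without recompiling. It is traced as an abstract array, so it becomes part of the computation graph: gradients can flow through it when you differentiate with respect to it, and it is subject to transformations such as vmap. This can also change numerical results. While the function is traced, a static value is a concrete Python number, so expressions involving only var and other Python numbers (for example 1.0 / var) are evaluated by Python in double precision before being embedded as constants, whereas a traced argument is, in JAX's default 32-bit mode, a float32 array and the same arithmetic runs in float32.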
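A sketch of the dynamic case, with hypothetical names (here `w` stands in for the rest of `Obj`) and assuming JAX's default 32-bit mode. It shows that a new value of `var` does not recompile, that `var` arrives as a float32 array during tracing, and that it can be differentiated and batched over.

```python
import jax
import jax.numpy as jnp

seen_dtypes = []  # appended to only while JAX is tracing `func`


@jax.jit
def func(w, var, x):
    # A Python float has no .dtype; a traced argument does.
    seen_dtypes.append(str(var.dtype))
    return w * x + var


w = jnp.float32(2.0)
x = jnp.float32(3.0)
func(w, 0.5, x)
func(w, 1.5, x)              # new value, same shape/dtype: no retrace
n_traces = len(seen_dtypes)  # 1

# Gradient with respect to var (positional argument 1).
dvar = jax.grad(func, argnums=1)(w, 0.5, x)

# Batch over var while keeping w and x fixed.
var_batch = jnp.array([0.0, 0.5, 1.0], dtype=jnp.float32)
y_batch = jax.vmap(func, in_axes=(None, 0, None))(w, var_batch, x)
```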

As a rule of thumb, use static class members for true constants that never change, and regular (non-static) function arguments for values you might differentiate through or batch over.
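One way to apply this rule of thumb, sketched with hypothetical names (`Model`, `decay`, `n_steps`, `simulate`, `dt`): a step count that drives a Python loop is a true constant and stays a static field, while the step size `dt` is a regular argument so it can be batched and differentiated.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


class Model:
    def __init__(self, decay, n_steps):
        self.decay = decay      # array leaf: traced
        self.n_steps = n_steps  # true constant: static auxiliary data


tree_util.register_pytree_node(
    Model,
    lambda m: ((m.decay,), m.n_steps),
    lambda n_steps, children: Model(children[0], n_steps),
)


@jax.jit
def simulate(model, dt, x0):
    # n_steps is a concrete Python int at trace time, so the loop is unrolled.
    x = x0
    for _ in range(model.n_steps):
        x = x - dt * model.decay * x  # explicit Euler step of x' = -decay * x
    return x


model = Model(jnp.float32(1.0), 3)
x0 = jnp.float32(1.0)

# dt is a regular argument: batch over several step sizes ...
dts = jnp.array([0.1, 0.2], dtype=jnp.float32)
xs = jax.vmap(simulate, in_axes=(None, 0, None))(model, dts, x0)

# ... or differentiate with respect to it.
dx_ddt = jax.grad(simulate, argnums=1)(model, jnp.float32(0.1), x0)
```

Here `xs` is approximately `(1 - dt) ** 3` for each step size, and `dx_ddt` is approximately `-3 * (1 - 0.1) ** 2`. Changing `n_steps` would recompile, which is acceptable for a value that rarely changes.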
