JAX class member or argument of a function

In JAX, I moved a constant variable from a class member to a function argument and observed different behavior. Here are the details:
- I have a function in JAX that takes a class object Obj as an input, so a call looks like func(Obj, ...), where ... stands for some other inputs.
- I had a constant variable as a member of the class, Obj.var. I used a pytree to handle the class, as suggested in the JAX documentation, and made Obj.var static when registering the pytree because I don't need to change it.
- Recently, I removed var from the class and made it a function argument, like f(Obj, var, ...). I didn't make var static because I may want to batch over it.
- Then the result of the function changed.
Is this expected? What is good practice for handling a constant value like this?
Answer
Yes, the difference in behavior is expected in JAX. When var was a static class member (registered as static in the pytree), JAX treated it as a compile-time constant: it is baked into any JIT-compiled code as a plain Python value, changing its value triggers a recompilation, and it isn't part of the differentiable computation graph.
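A minimal sketch of the original setup, assuming a hypothetical class Model with one traced field w and the constant var registered as static auxiliary data through jax.tree_util.register_pytree_node_class. The names Model, func and trace_log are illustrative. It shows that each new static value of var forces a fresh trace, and that jax.grad returns no derivative for var.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class Model:
    """Hypothetical stand-in for Obj: w is a traced leaf, var is static."""

    def __init__(self, w, var):
        self.w = w
        self.var = var

    def tree_flatten(self):
        children = (self.w,)    # traced leaves
        aux_data = (self.var,)  # static: stored in the tree structure
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(children[0], aux_data[0])


trace_log = []  # records var each time func is traced


@jax.jit
def func(obj, x):
    trace_log.append(obj.var)  # Python side effect: runs only while tracing
    return jnp.sum(obj.w * x) * obj.var  # obj.var is a plain Python float here


x = jnp.arange(3.0)
m1 = Model(jnp.ones(3), 2.0)
m2 = Model(jnp.ones(3), 3.0)

y1 = func(m1, x)  # traces and compiles with var=2.0 baked in
y2 = func(m2, x)  # new static value: retrace and recompile
n_traces = len(trace_log)  # 2

# jax.grad differentiates w; grads.var is just the static 2.0 carried
# along in the tree structure, not a derivative.
grads = jax.grad(func)(m1, x)
```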
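When you move it to a function argument, JAX instead sees it as a dynamic value that can change between calls without recompiling. It is traced as an array, so it becomes part of the computation graph: gradients can flow through it (when you differentiate with respect to that argument, e.g. with the argnums parameter of jax.grad), and it is subject to transformations such as vmap. One plausible source of a changed result: any arithmetic your code does on a static Python float before it meets a JAX array runs in Python's 64-bit precision at trace time, whereas a traced argument is a float32 array by default (unless 64-bit mode is enabled), so the same arithmetic then runs in float32 and results can differ slightly.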
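The same computation with var moved to a function argument, again a sketch with hypothetical names (Model now holds only w). Because var is traced, a new value reuses the compiled code, jax.vmap can batch over it, and jax.grad differentiates it only when it is selected with argnums.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class Model:
    """Hypothetical class after the change: var is no longer a field."""

    def __init__(self, w):
        self.w = w

    def tree_flatten(self):
        return (self.w,), None  # one traced leaf, no static data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


trace_log = []  # one entry per trace of func


@jax.jit
def func(obj, var, x):
    trace_log.append(1)  # Python side effect: runs only while tracing
    return jnp.sum(obj.w * x) * var  # var is a traced scalar here


m = Model(jnp.ones(3))
x = jnp.arange(3.0)

y_a = func(m, 2.0, x)
y_b = func(m, 3.0, x)  # same shape and dtype: reuses the compiled code
n_traces = len(trace_log)  # 1

# Batch over var only; obj and x are shared across the batch.
ys = jax.vmap(func, in_axes=(None, 0, None))(m, jnp.array([1.0, 2.0, 3.0]), x)

# var is differentiated only when selected with argnums.
dvar = jax.grad(func, argnums=1)(m, 2.0, x)
```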
As a rule of thumb, use static class members for true constants that never change (or that must be known at trace time), and use function arguments for values you might differentiate through or batch over.
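Both kinds of value can live side by side. In this hypothetical sketch, the class Series keeps n_terms as a static field because a Python loop needs it at trace time, while var is an ordinary argument that jax.vmap batches over. Series, evaluate and n_terms are illustrative names, not from the question.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class Series:
    """Hypothetical class: coeffs are traced data, n_terms is a true constant."""

    def __init__(self, coeffs, n_terms):
        self.coeffs = coeffs
        self.n_terms = n_terms  # plain Python int, registered as static below

    def tree_flatten(self):
        return (self.coeffs,), (self.n_terms,)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(children[0], aux_data[0])


@jax.jit
def evaluate(obj, var):
    # n_terms is static, so this Python loop is unrolled at trace time.
    total = 0.0
    for k in range(obj.n_terms):
        total = total + obj.coeffs[k] * var ** k
    return total


s = Series(jnp.array([1.0, 2.0, 3.0, 4.0]), n_terms=3)
single = evaluate(s, 2.0)  # 1 + 2*2 + 3*4
batched = jax.vmap(evaluate, in_axes=(None, 0))(s, jnp.array([0.0, 1.0, 2.0]))
```

Making n_terms a traced argument instead would fail here, since range() needs a concrete Python int.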