In JAX, I changed a constant variable from a class member to a function argument and observed different behavior. Here are the details:
- I have a function in JAX that takes a class object `Obj` as an input, so the function is called like `func(Obj, ...)`, where `...` stands for some other inputs.
- I had a constant variable as a member of the class, like `Obj.var`. I used a pytree to handle the class, as suggested in the JAX documentation, and I made `Obj.var` static when registering the pytree because I don't need to change it.
- Recently, I removed `var` from the class and made it an argument of the function, like `func(Obj, var, ...)`. I didn't make `var` static because I may want to batch over it.
- After that, the result of the function changed.
Is this expected? What is good practice for handling a constant value like this?
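For reference, here is a minimal sketch of the setup described above. The class name `MyClass`, the field `w`, and both function bodies are hypothetical. The class is registered with `jax.tree_util.register_pytree_node_class`, and `var` is made static by returning it as auxiliary data from `tree_flatten` rather than as a child leaf.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


# Hypothetical class standing in for the question's `Obj`.
# `w` is a dynamic leaf; `var` is stored as static auxiliary data.
@register_pytree_node_class
class MyClass:
    def __init__(self, w, var):
        self.w = w
        self.var = var

    def tree_flatten(self):
        # children are traced by jit/grad/vmap; aux_data is kept static
        # (it must be hashable, e.g. a Python float)
        return (self.w,), self.var

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(children[0], aux_data)


@jax.jit
def func_static(obj, x):
    # Before: obj.var is a plain Python value, baked into the compiled code
    return obj.w * x + obj.var


@jax.jit
def func_arg(obj, var, x):
    # After: var arrives as a traced array (obj.var is simply not read here,
    # standing in for removing it from the class)
    return obj.w * x + var


obj = MyClass(jnp.float32(2.0), 0.5)
x = jnp.arange(3.0)
out_static = func_static(obj, x)
out_arg = func_arg(obj, 0.5, x)
```

For a simple expression like this, both versions produce the same values; differences show up when the code relies on `var` being a concrete Python value.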
Answer
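Yes, a difference in behavior can be expected in JAX. When `var` was a static class member (stored as static auxiliary data in the pytree), JAX treated it as a compile-time constant: during tracing it is an ordinary Python value, it gets baked into the JIT-compiled code (a new value triggers recompilation), and it isn't part of the differentiable computation graph.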
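The compile-time-constant behavior can be observed with a small sketch. For brevity it uses `static_argnums` instead of a static pytree field (both make the value part of the compilation cache key); the functions `f` and `g` and the counter are hypothetical. A Python side effect runs only while JAX traces a function, so counting it shows when recompilation happens.

```python
import jax
import jax.numpy as jnp

trace_count = {"static": 0, "traced": 0}


def make_f(kind):
    def f(x, var):
        # Python side effect: runs only while JAX traces (compiles) f
        trace_count[kind] += 1
        return x * var
    return f


f_static = jax.jit(make_f("static"), static_argnums=1)  # var is a constant
f_traced = jax.jit(make_f("traced"))                   # var is traced

x = jnp.ones(3)
for v in (1.0, 2.0, 3.0):
    f_static(x, v)  # a new value of a static var triggers a retrace
    f_traced(x, v)  # same shape/dtype, so the compiled code is reused

print(trace_count)  # static: 3 traces, traced: 1 trace


def g(x, var):
    # Python control flow needs a concrete value, which only a static var is
    if var > 0:
        return x * var
    return x


jax.jit(g, static_argnums=1)(x, 2.0)  # works: var is a Python float here
branch_failed = False
try:
    jax.jit(g)(x, 2.0)
except jax.errors.ConcretizationTypeError:
    branch_failed = True  # a traced var cannot be used in a Python `if`
print(branch_failed)
```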
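By contrast, when you move it to a function argument that is not marked static, JAX sees it as a dynamic value that can change between calls without recompiling. As a result:
- it is traced as a JAX array rather than a concrete Python value (a Python float, for instance, becomes a float32 array unless 64-bit mode is enabled), so it becomes part of the computation graph; arithmetic on `var` that previously happened in Python at trace time now runs inside the compiled computation, which can change numerical results slightly, and `var` can no longer drive Python control flow or array shapes;
- gradients can flow through it when you differentiate with respect to it (for example with `jax.grad(..., argnums=...)`);
- it is subject to transformations such as `vmap`.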
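Once `var` is an ordinary argument, these transformations apply to it directly. A minimal sketch with a hypothetical `loss` function: `jax.grad(..., argnums=1)` differentiates with respect to `var`, and `jax.vmap` with `in_axes=(None, 0, None)` batches over several values of `var` while keeping `w` and `x` fixed.

```python
import jax
import jax.numpy as jnp


def loss(w, var, x):
    # var is an ordinary (traced) argument, so transformations can see it
    return jnp.sum((w * x - var) ** 2)


w = jnp.float32(2.0)
x = jnp.arange(3.0)

# Differentiate with respect to var (argument position 1)
dloss_dvar = jax.grad(loss, argnums=1)(w, 0.5, x)

# Batch over several values of var; None means "do not map this argument"
vars_ = jnp.array([0.0, 0.5, 1.0])
losses = jax.vmap(loss, in_axes=(None, 0, None))(w, vars_, x)
```

With a static `var`, neither of these would be possible, since JAX would see `var` only as a fixed Python value.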
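As a rule of thumb, use static class members (or static arguments) for true constants that never change, especially values that control Python logic or array shapes, and use ordinary function arguments for values you might differentiate through or batch over.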
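A sketch of that rule of thumb, using a hypothetical function `weighted_sum`: `n` determines an array shape, so it stays static via `static_argnames`, while `var` is a normal traced value that is batched with `vmap` inside the jitted function.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames="n")
def weighted_sum(vars_, n):
    # n sets the length of jnp.arange, so it must be a concrete (static) value
    def single(var):
        # var is traced, so it can be batched over (or differentiated)
        return jnp.sum(jnp.arange(n) * var)

    return jax.vmap(single)(vars_)


result = weighted_sum(jnp.array([1.0, 2.0]), n=4)
```

Changing `n` triggers a recompilation; changing the values in `vars_` does not.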