jax==0.3.25
jaxlib==0.3.25
flax==0.6.3
optax>=0.1.4
Pillow>=9.4.0
numpy==1.23.5

[:python_version < "3.11"]
tomli

[dev]
black
isort
pip-tools
pytest
