Ever since I started playing around with AI art creation, I have not had much success with JAX; in fact, I gave up on it after several attempts. I found it hard to understand and use, and it kept crashing even when I had faster GPUs allocated.
However, as I kept seeing great art being created with JAX, I finally reached out to one of the JAX experts on Twitter, and he kindly agreed to help me over a private Discord chat. In this post, I share my findings based on @huemen’s help (make sure you follow him) so that others can get started with JAX Diffusion without the pain, regret, or hesitation I went through.
There are two popular notebooks to choose from: one by NeuralismAI and another by Huemen. I use the latter, so my suggestions and tips are based on that notebook. They should work the same way in the other notebook, but I have not tried it (yet).
Tips and Suggestions
The default settings in this notebook don’t produce good results (at the time of writing), so Huemen suggested the following settings:
- choose_diffusion_model: cc12m
- use_vitb16 and use_vitb32 are ticked/selected (using use_vitl14 caused my runs to crash, probably because there was not enough GPU RAM, though I’m only guessing)
- image_size: (768, 576) for a landscape image or (576, 768) for a portrait image (anything bigger also failed to run for me)
- batch_size: 1
- n_batches: 5 (this gives me 5 different images from one prompt; you can set it to 1 for experiments)
- clip_guidance_scale: 100000 (Huemen suggested a value between 80k and 100k)
- tv_scale: -1 (creates sharper images for me)
- range_scale: -1 (seems to add more contrast)
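If you like to keep a record of the settings above before typing them into the notebook’s form fields, the following is a minimal Python sketch of such a checklist. The keys reuse the parameter names listed above, but the dict and the helper functions `suggested_settings` and `images_per_run` are hypothetical and not part of the notebook. The sketch also assumes that each run produces `batch_size` images per batch across `n_batches` batches, which matches batch_size 1 and n_batches 5 giving me 5 images.

```python
# Hypothetical checklist of the suggested settings, keyed by the notebook's
# form-field names. It is only a record for copying values by hand; the
# notebook does not read this dict.
LANDSCAPE = (768, 576)  # larger sizes failed to run for me
PORTRAIT = (576, 768)


def suggested_settings(portrait: bool = False) -> dict:
    """Return the suggested settings, swapping image_size for portrait images."""
    return {
        "choose_diffusion_model": "cc12m",
        "use_vitb16": True,
        "use_vitb32": True,
        "use_vitl14": False,  # crashed my runs, probably not enough GPU RAM
        "image_size": PORTRAIT if portrait else LANDSCAPE,
        "batch_size": 1,
        "n_batches": 5,  # set to 1 while experimenting
        "clip_guidance_scale": 100_000,  # Huemen suggested 80k to 100k
        "tv_scale": -1,  # sharper images for me
        "range_scale": -1,  # seems to add more contrast
    }


def images_per_run(settings: dict) -> int:
    # Assumption: each of the n_batches batches yields batch_size images.
    return settings["batch_size"] * settings["n_batches"]


landscape = suggested_settings()
print(landscape["image_size"], images_per_run(landscape))  # (768, 576) 5
```

Calling `suggested_settings(portrait=True)` gives the same values with `image_size` set to (576, 768).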
Huemen also suggested that I try tweaking other settings, such as mean_scale and var_scale, but I have not experimented much with these.
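When tweaking settings, it can help to change one parameter at a time against a fixed base. The following Python sketch uses a hypothetical helper, `variants`, that builds one copy of a base configuration per combination of values. It is an illustration only, not a feature of the notebook. The example sweeps clip_guidance_scale across Huemen’s suggested 80k to 100k range. You could add mean_scale or var_scale keys to the grid in the same way, but I don’t suggest values for them here because I have not tested any.

```python
from itertools import product

# Hypothetical base configuration built from the settings above, with
# n_batches set to 1 so each variant finishes quickly.
BASE = {
    "choose_diffusion_model": "cc12m",
    "image_size": (768, 576),
    "batch_size": 1,
    "n_batches": 1,
    "clip_guidance_scale": 100_000,
    "tv_scale": -1,
    "range_scale": -1,
}


def variants(base: dict, grid: dict) -> list:
    """Return one copy of base for every combination of the values in grid."""
    keys = list(grid)
    configs = []
    for values in product(*(grid[k] for k in keys)):
        cfg = dict(base)  # shallow copy so the base stays unchanged
        cfg.update(zip(keys, values))
        configs.append(cfg)
    return configs


runs = variants(BASE, {"clip_guidance_scale": [80_000, 90_000, 100_000]})
for cfg in runs:
    print(cfg["clip_guidance_scale"])  # 80000, then 90000, then 100000
```

Each resulting config would then be entered into the notebook by hand, one run at a time.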
A public document provides more details about the parameters and is a good reference once you have this running.