Flux TE Neutered
A text encoder designed to serve as a lightweight alternative to the T5 model. For use with the Flux Dev model. Trained in the shadow of a giant.
While the model can still follow prompts to some extent, its performance in this area is noticeably worse than the original text encoder.
Download
Download the T5 variant, the optional preview decoder, and the FP8 quant.
Lightweight requirements:
text_encoder.py
text_encoder_2.safetensors
Setup
pip install accelerate diffusers einops optimum-quanto protobuf sentencepiece transformers
Inference
from diffusers import FluxPipeline, FluxTransformer2DModel
from optimum.quanto.models import QuantizedDiffusersModel, QuantizedTransformersModel
import torch
from text_encoder import t5_config, T5EncoderModel, PretrainedTextEncoder
class Flux2DModel(QuantizedDiffusersModel):
    """Quanto wrapper used to load the FP8-quantized Flux transformer checkpoint."""
    base_class = FluxTransformer2DModel
if __name__ == '__main__':
    # Build the lightweight T5 replacement and load its weights in fp16.
    t5 = PretrainedTextEncoder(t5_config, T5EncoderModel(t5_config)).to(dtype=torch.float16)
    t5.load_model('text_encoder_2.safetensors')
    # Optional: swap in the FP8-quantized transformer to cut VRAM further.
    # transformer = Flux2DModel.from_pretrained('./flux-fp8')._wrapped
    pipe = FluxPipeline.from_pretrained('black-forest-labs/FLUX.1-dev',
                                        text_encoder_2=t5,
                                        # transformer=transformer
                                        )
    # Offload idle submodules to CPU so the pipeline fits in less VRAM.
    pipe.enable_model_cpu_offload()
    image = pipe('a black cat wearing a Pikachu cosplay', num_inference_steps=10, output_type='pil').images[0]
    image.save('cat.png')
The following example saves a preview image by decoding the latents with the lightweight Tea decoder instead of the full VAE.
from diffusers import FluxPipeline, FluxTransformer2DModel
from optimum.quanto.models import QuantizedDiffusersModel, QuantizedTransformersModel
import torch
from tea_model import TeaDecoder
from text_encoder import t5_config, T5EncoderModel, PretrainedTextEncoder
class Flux2DModel(QuantizedDiffusersModel):
    """Quanto wrapper used to load the FP8-quantized Flux transformer checkpoint."""
    base_class = FluxTransformer2DModel
def preview_image(pipe, latents):
    """Decode packed Flux latents into a PIL preview image.

    Uses the lightweight Tea decoder (``vae_decoder.safetensors``) rather
    than the pipeline's full VAE.

    :param pipe: the FluxPipeline the latents came from; provides the
        default sample size and VAE scale factor needed to unpack them.
    :param latents: packed latent tensor returned by the pipeline when
        called with ``output_type='latent'``.
    :return: a PIL image preview of the first sample in the batch.
    """
    # The original snippet used these names without importing them, which
    # raises NameError at call time. Import locally so the module stays
    # importable even when torchvision/safetensors are absent.
    from safetensors.torch import load_model
    from torchvision import transforms

    # Height and width are identical for the default square sample; hoist
    # the repeated expression.
    side = pipe.default_sample_size * pipe.vae_scale_factor
    latents = FluxPipeline._unpack_latents(latents, side, side,
                                           pipe.vae_scale_factor)
    tea = TeaDecoder(ch_in=16)
    load_model(tea, './vae_decoder.safetensors')
    tea = tea.to(device='cuda')
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        # Pin the latents to the decoder's device as well as its dtype —
        # a no-op if they are already on CUDA, and a fix for the
        # device-mismatch error if the pipeline left them on CPU.
        output = tea(latents.to(device='cuda', dtype=torch.float32)) / 2.0 + 0.5
    # Map [0, 1] floats to a PIL image; clamp guards against overshoot.
    return transforms.ToPILImage()(output[0].clamp(0, 1))
if __name__ == '__main__':
    # Build the lightweight T5 replacement and load its weights in fp16.
    t5 = PretrainedTextEncoder(t5_config, T5EncoderModel(t5_config)).to(dtype=torch.float16)
    t5.load_model('text_encoder_2.safetensors')
    # Optional: swap in the FP8-quantized transformer to cut VRAM further.
    # transformer = Flux2DModel.from_pretrained('./flux-fp8')._wrapped
    pipe = FluxPipeline.from_pretrained('black-forest-labs/FLUX.1-dev',
                                        text_encoder_2=t5,
                                        # transformer=transformer
                                        )
    # Offload idle submodules to CPU so the pipeline fits in less VRAM.
    pipe.enable_model_cpu_offload()
    # Stop before VAE decoding: request raw latents from the pipeline...
    latents = pipe('cat playing piano', num_inference_steps=10, output_type='latent').images
    # ...and decode them with the lightweight Tea decoder instead.
    preview = preview_image(pipe, latents)
    preview.save('cat.png')
Disclaimer
Use of this code and the model requires citation and attribution to the author via a link to their Hugging Face profile in all resulting work.
Model tree for twodgirl/flux-text-encoder-neutered
Base model
black-forest-labs/FLUX.1-dev
Finetuned
this model