"""Gradio front end for the nanograd Stable Diffusion text-to-image pipeline.

Importing this module builds the CLIP tokenizer and loads the Stable
Diffusion weights from the hard-coded paths below; running it as a script
launches the web UI.
"""

from pathlib import Path

import gradio as gr
import torch
from PIL import Image
from transformers import CLIPTokenizer

from nanograd.models.stable_diffusion import model_loader, pipeline

DEVICE = "cpu"

# Opt-in switches for hardware acceleration. Both are off, so DEVICE stays "cpu".
ALLOW_CUDA = False
ALLOW_MPS = False

# Check the opt-in flag first: a disabled backend is then never probed, which
# avoids an AttributeError on PyTorch builds that have no `torch.backends.mps`.
if ALLOW_CUDA and torch.cuda.is_available():
    DEVICE = "cuda"
elif ALLOW_MPS and torch.backends.mps.is_available():
    DEVICE = "mps"
print(f"Using device: {DEVICE}")

# NOTE(review): these are absolute, machine-specific Windows paths; the app
# only starts on a machine where these exact files exist.
tokenizer_vocab_path = Path("C:\\Users\\Esmail\\Desktop\\nanograd\\nanograd\\models\\stable_diffusion\\sd_data\\tokenizer_vocab.json")
tokenizer_merges_path = Path("C:\\Users\\Esmail\\Desktop\\nanograd\\nanograd\\models\\stable_diffusion\\sd_data\\tokenizer_merges.txt")
model_file = Path("C:\\Users\\Esmail\\Desktop\\nanograd\\nanograd\\models\\stable_diffusion\\sd_data\\v1-5-pruned-emaonly.ckpt")

# Built once at import time and shared by every generation request.
tokenizer = CLIPTokenizer(str(tokenizer_vocab_path), merges_file=str(tokenizer_merges_path))
models = model_loader.preload_models_from_standard_weights(str(model_file), DEVICE)


def generate_image(
    prompt: str,
    cfg_scale: float,
    num_inference_steps: int,
    sampler: str,
) -> Image.Image:
    """Run the Stable Diffusion pipeline for *prompt* and return a PIL image.

    Args:
        prompt: Text prompt forwarded to ``pipeline.generate``.
        cfg_scale: Value forwarded as ``cfg_scale``; classifier-free
            guidance is always requested (``do_cfg=True``).
        num_inference_steps: Number of sampling steps. Coerced to ``int``
            before being forwarded as ``n_inference_steps``.
        sampler: Sampler name forwarded as ``sampler_name``.

    Returns:
        The pipeline output converted with ``Image.fromarray``.
    """
    # Fixed settings: empty unconditional prompt, no input image, and a
    # constant seed passed on every call.
    uncond_prompt = ""
    do_cfg = True
    input_image = None
    strength = 0.9
    seed = 42

    generated = pipeline.generate(
        prompt=prompt,
        uncond_prompt=uncond_prompt,
        input_image=input_image,
        strength=strength,
        do_cfg=do_cfg,
        cfg_scale=cfg_scale,
        sampler_name=sampler,
        # Gradio documents slider values as floats; a step count must be
        # a whole number, so normalise it here.
        n_inference_steps=int(num_inference_steps),
        seed=seed,
        models=models,
        device=DEVICE,
        idle_device="cpu",
        tokenizer=tokenizer,
    )

    # presumably `generated` is an image array accepted by PIL — TODO confirm
    return Image.fromarray(generated)


# Gradio interface
def gradio_interface() -> None:
    """Build the Blocks UI, wire the Generate button, and launch the app."""
    with gr.Blocks() as demo:
        with gr.Row():
            # Left column: generation controls.
            with gr.Column(scale=2):
                prompt_input = gr.Textbox(
                    label="Prompt",
                    placeholder="A cat stretching on the floor, highly detailed, ultra sharp, cinematic, 100mm lens, 8k resolution",
                )
                cfg_scale = gr.Slider(
                    label="CFG Scale",
                    minimum=1,
                    maximum=20,
                    value=7,
                    step=1,
                )
                num_inference_steps = gr.Slider(
                    label="Sampling Steps",
                    minimum=10,
                    maximum=100,
                    value=20,
                    step=5,
                )
                # NOTE(review): this file cannot show which of these names
                # `pipeline.generate` accepts as `sampler_name`; only the
                # default "ddpm" is exercised here — verify the others.
                sampler = gr.Radio(
                    label="Sampling Method",
                    choices=["ddpm", "Euler a", "Euler", "LMS", "Heun", "DPM2 a", "PLMS"],
                    value="ddpm",
                )
                generate_btn = gr.Button("Generate", variant="primary")

            # Right column: the generated image.
            with gr.Column(scale=2):
                output_image = gr.Image(
                    label="Output",
                    show_label=False,
                    height=512,
                    width=512,
                )

        # Inputs are passed positionally, in generate_image's parameter order.
        generate_btn.click(
            fn=generate_image,
            inputs=[prompt_input, cfg_scale, num_inference_steps, sampler],
            outputs=output_image,
        )

    demo.launch()


if __name__ == "__main__":
    gradio_interface()