✨ | Use SDXL from rupeshs instead
main.py (58 changed lines)
@@ -1,33 +1,30 @@
 from flask import Flask, request, jsonify
-from diffusers import DDPMScheduler, DiffusionPipeline
 import torch
+from optimum.intel.openvino.modeling_diffusion import OVStableDiffusionXLPipeline
 import base64
 from io import BytesIO
 from PIL import Image
 import os
 
 app = Flask(__name__)
 
-scheduler = DDPMScheduler.from_pretrained("stabilityai/stable-diffusion-2-1-base", subfolder="scheduler")
-pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base", scheduler=scheduler)
+# Define paths for OpenVINO IR models
+OV_MODEL_DIR = "echarlaix/stable-diffusion-2-1-openvino"  # New directory for the LCM model
+
+# --- Model Loading ---
+pipe = None
 
-print("Loading UNet weights...")
-ckpt_path = './models/model.bin'
 try:
-    unet_state_dict = torch.load(ckpt_path)
-    pipe.unet.load_state_dict(unet_state_dict)
-    if torch.cuda.is_available():
-        pipe = pipe.to("cuda")
-        pipe.to(torch.float16)
-        print("UNet weights loaded successfully and model moved to CUDA.")
-    else:
-        print("CUDA is not available. Running on CPU (will be slow).")
-        pipe = pipe.to("cpu")
-        pipe.to(torch.float32)
-except FileNotFoundError:
-    print(f"Error: Model not found at {ckpt_path}. Please ensure the file exists.")
-    exit()
+    pipe = OVStableDiffusionXLPipeline.from_pretrained(
+        "rupeshs/hyper-sd-sdxl-1-step-openvino-int8",
+        ov_config={"CACHE_DIR": ""},
+    )
+
+    print("Compiling OpenVINO pipeline...")
+    pipe.compile()  # Compile the pipeline for the target device (CPU by default)
+    print("OpenVINO compiled successfully.")
+
 except Exception as e:
-    print(f"An error occurred while loading UNet weights: {e}")
+    print(f"An error occurred during OpenVINO LCM model loading or conversion: {e}")
     exit()
 
 @app.route('/api/generate', methods=['POST'])
@@ -36,23 +33,28 @@ def generate_image():
     if not prompt:
         return jsonify({"error": "Prompt is required in the request body."}), 400
 
-    print(f"Generating image for prompt: '{prompt}'...")
+    print(f"Generating image for prompt: '{prompt}' using OpenVINO LCM...")
     try:
-        image = pipe(prompt=prompt, num_inference_steps=1, guidance_scale=0, timesteps=[999]).images[0]
+        # Crucially, set num_inference_steps=1 for 1-step generation with LCMs
+        # You might even omit guidance_scale for some LCMs, or use a very low value.
+        image = pipe(
+            prompt=prompt,
+            width=768,
+            height=768,
+            num_inference_steps=1,
+            guidance_scale=1.0,
+        ).images[0]
-        image = image.resize((128, 128), Image.LANCZOS)
 
         # Convert image to base64
         buffered = BytesIO()
-        image.save(buffered, format="JPEG")  # You can choose JPEG for smaller size, but PNG is lossless
+        image.save(buffered, format="PNG")
         img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
 
         return jsonify({"image": img_str})
 
     except Exception as e:
-        print(f"Error during image generation: {e}")
+        print(f"Error during image generation with OpenVINO LCM: {e}")
         return jsonify({"error": "An error occurred during image generation."}), 500
 
 if __name__ == '__main__':
-    # Make sure to run this with 'python app.py'
-    # It will be accessible at http://127.0.0.1:5000/
-    app.run(host='127.0.0.1', port=5000, debug=True)  # debug=True is good for development
+    app.run(host='127.0.0.1', port=5000, debug=True)
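
For reference, a minimal client-side sketch for exercising the endpoint changed above. It assumes the handler reads a JSON body with a "prompt" key (the key name is not shown in this hunk) and uses the route, host, port, and {"image": <base64>} response shape from the code above; the requests package is an extra dependency.

import base64
import requests

# Hypothetical client for the /api/generate endpoint in main.py.
# Assumes the server expects JSON like {"prompt": "..."} (not shown in this diff).
resp = requests.post(
    "http://127.0.0.1:5000/api/generate",
    json={"prompt": "a photo of an astronaut riding a horse"},
    timeout=300,
)
resp.raise_for_status()

# The endpoint returns {"image": "<base64-encoded PNG>"}.
img_b64 = resp.json()["image"]
with open("generated.png", "wb") as f:
    f.write(base64.b64decode(img_b64))
print("Saved generated.png")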
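
A standalone sketch for checking that the rupeshs model loads and generates outside Flask, using the same calls and parameters as the new code above. It assumes optimum-intel with OpenVINO support is installed; the prompt and output filename are illustrative, and the first run downloads the model weights.

from optimum.intel.openvino.modeling_diffusion import OVStableDiffusionXLPipeline

# Same model and options as in main.py above.
pipe = OVStableDiffusionXLPipeline.from_pretrained(
    "rupeshs/hyper-sd-sdxl-1-step-openvino-int8",
    ov_config={"CACHE_DIR": ""},
)
pipe.compile()  # compile for the default device (CPU)

# 1-step generation, mirroring the parameters used by the endpoint.
image = pipe(
    prompt="a cozy cabin in the woods, golden hour",
    width=768,
    height=768,
    num_inference_steps=1,
    guidance_scale=1.0,
).images[0]
image.save("smoke_test.png")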