Use SDXL from rupeshs instead

2025-07-07 21:51:59 +03:00
parent 9fefd0ff9f
commit b4e1a5ffae
2 changed files with 110 additions and 40 deletions

main.py

@@ -1,33 +1,30 @@
 from flask import Flask, request, jsonify
-from diffusers import DDPMScheduler, DiffusionPipeline
-import torch
+from optimum.intel.openvino.modeling_diffusion import OVStableDiffusionXLPipeline
 import base64
 from io import BytesIO
 from PIL import Image
 import os
 app = Flask(__name__)
-scheduler = DDPMScheduler.from_pretrained("stabilityai/stable-diffusion-2-1-base", subfolder="scheduler")
-pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base", scheduler=scheduler)
-# Define paths for OpenVINO IR models
-OV_MODEL_DIR = "echarlaix/stable-diffusion-2-1-openvino"  # New directory for the LCM model
 # --- Model Loading ---
 pipe = None
-print("Loading UNet weights...")
-ckpt_path = './models/model.bin'
 try:
-    unet_state_dict = torch.load(ckpt_path)
-    pipe.unet.load_state_dict(unet_state_dict)
-    if torch.cuda.is_available():
-        pipe = pipe.to("cuda")
-        pipe.to(torch.float16)
-        print("UNet weights loaded successfully and model moved to CUDA.")
-    else:
-        print("CUDA is not available. Running on CPU (will be slow).")
-        pipe = pipe.to("cpu")
-        pipe.to(torch.float32)
-except FileNotFoundError:
-    print(f"Error: Model not found at {ckpt_path}. Please ensure the file exists.")
-    exit()
+    pipe = OVStableDiffusionXLPipeline.from_pretrained(
+        "rupeshs/hyper-sd-sdxl-1-step-openvino-int8",
+        ov_config={"CACHE_DIR": ""},
+    )
+    print("Compiling OpenVINO pipeline...")
+    pipe.compile()  # Compile the pipeline for the target device (CPU by default)
+    print("OpenVINO compiled successfully.")
 except Exception as e:
-    print(f"An error occurred while loading UNet weights: {e}")
+    print(f"An error occurred during OpenVINO LCM model loading or conversion: {e}")
     exit()
 @app.route('/api/generate', methods=['POST'])
@@ -36,23 +33,28 @@ def generate_image():
     if not prompt:
         return jsonify({"error": "Prompt is required in the request body."}), 400
-    print(f"Generating image for prompt: '{prompt}'...")
+    print(f"Generating image for prompt: '{prompt}' using OpenVINO LCM...")
     try:
-        image = pipe(prompt=prompt, num_inference_steps=1, guidance_scale=0, timesteps=[999]).images[0]
+        # Crucially, set num_inference_steps=1 for 1-step generation with LCMs.
+        # You might even omit guidance_scale for some LCMs, or use a very low value.
+        image = pipe(
+            prompt=prompt,
+            width=768,
+            height=768,
+            num_inference_steps=1,
+            guidance_scale=1.0,
+        ).images[0]
+        image = image.resize((128, 128), Image.LANCZOS)
         # Convert image to base64
         buffered = BytesIO()
-        image.save(buffered, format="JPEG")  # You can choose JPEG for smaller size, but PNG is lossless
+        image.save(buffered, format="PNG")
         img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
         return jsonify({"image": img_str})
     except Exception as e:
-        print(f"Error during image generation: {e}")
+        print(f"Error during image generation with OpenVINO LCM: {e}")
        return jsonify({"error": "An error occurred during image generation."}), 500
 if __name__ == '__main__':
     # Make sure to run this with 'python app.py'
     # It will be accessible at http://127.0.0.1:5000/
-    app.run(host='127.0.0.1', port=5000, debug=True)  # debug=True is good for development
+    app.run(host='127.0.0.1', port=5000, debug=True)
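
For reference, a minimal client sketch for exercising the /api/generate endpoint above. It assumes the server is running locally on port 5000 as configured, and that the handler reads the prompt from the JSON request body (as its error message suggests); the prompt text and output filename are illustrative:

import base64
import requests

# Hypothetical smoke test for the endpoint defined in main.py.
resp = requests.post(
    "http://127.0.0.1:5000/api/generate",
    json={"prompt": "a lighthouse at sunset"},  # example prompt
    timeout=300,  # the first request can be slow while the pipeline warms up
)
resp.raise_for_status()

# The server returns a base64-encoded 128x128 PNG under the "image" key.
with open("out.png", "wb") as f:
    f.write(base64.b64decode(resp.json()["image"]))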

requirements.txt

@@ -1,39 +1,107 @@
+about-time==4.2.1
 accelerate==1.8.1
+aiohappyeyeballs==2.6.1
+aiohttp==3.12.13
+aiosignal==1.4.0
+alive-progress==3.2.0
+attrs==25.3.0
+autograd==1.8.0
 blinker==1.9.0
 certifi==2025.6.15
 charset-normalizer==3.4.2
 click==8.2.1
+cma==4.2.0
+contourpy==1.3.2
+cycler==0.12.1
+datasets==3.6.0
+Deprecated==1.2.18
 diffusers==0.34.0
+dill==0.3.8
 filelock==3.18.0
 Flask==3.1.1
-fsspec==2025.5.1
+fonttools==4.58.5
+frozenlist==1.7.0
+fsspec==2025.3.0
+grapheme==0.6.0
 hf-xet==1.1.5
 huggingface-hub==0.33.2
 idna==3.10
 importlib_metadata==8.7.0
 itsdangerous==2.2.0
 Jinja2==3.1.6
-MarkupSafe==2.1.5
+joblib==1.5.1
+jsonschema==4.24.0
+jsonschema-specifications==2025.4.1
+kiwisolver==1.4.8
+markdown-it-py==3.0.0
+MarkupSafe==3.0.2
+matplotlib==3.10.3
+mdurl==0.1.2
 mpmath==1.3.0
-networkx==3.5
-numpy==2.3.1
+multidict==6.6.3
+multiprocess==0.70.16
+natsort==8.4.0
+networkx==3.4.2
+ninja==1.11.1.4
+nncf==2.17.0
+numpy==2.2.6
+nvidia-cublas-cu12==12.6.4.1
+nvidia-cuda-cupti-cu12==12.6.80
+nvidia-cuda-nvrtc-cu12==12.6.77
+nvidia-cuda-runtime-cu12==12.6.77
+nvidia-cudnn-cu12==9.5.1.17
+nvidia-cufft-cu12==11.3.0.4
+nvidia-cufile-cu12==1.11.1.6
+nvidia-curand-cu12==10.3.7.77
+nvidia-cusolver-cu12==11.7.1.2
+nvidia-cusparse-cu12==12.5.4.2
+nvidia-cusparselt-cu12==0.6.3
+nvidia-nccl-cu12==2.26.2
+nvidia-nvjitlink-cu12==12.6.85
+nvidia-nvtx-cu12==12.6.77
+onnx==1.18.0
+openvino==2025.2.0
+openvino-telemetry==2025.2.0
+openvino-tokenizers==2025.2.0.1
+optimum==1.26.1
+optimum-intel==1.24.0
 packaging==25.0
-pillow==11.2.1
+pandas==2.2.3
+pillow==11.3.0
+propcache==0.3.2
+protobuf==6.31.1
 psutil==7.0.0
-pytorch-triton-rocm==3.3.1+gitc8757738
+pyarrow==20.0.0
+pydot==3.0.4
+Pygments==2.19.2
+pymoo==0.6.1.5
+pyparsing==3.2.3
+python-dateutil==2.9.0.post0
+pytz==2025.2
 PyYAML==6.0.2
+referencing==0.36.2
 regex==2024.11.6
 requests==2.32.4
+rich==14.0.0
+rpds-py==0.26.0
 safetensors==0.5.3
-setuptools==78.1.0
+scikit-learn==1.7.0
+scipy==1.16.0
+setuptools==80.9.0
+six==1.17.0
 sympy==1.14.0
+tabulate==0.9.0
+threadpoolctl==3.6.0
 tokenizers==0.21.2
-torch==2.9.0.dev20250706+rocm6.4
-torchaudio==2.8.0.dev20250706+rocm6.4
-torchvision==0.24.0.dev20250706+rocm6.4
+torch==2.7.1
 tqdm==4.67.1
-transformers==4.53.1
-typing_extensions==4.14.0
+transformers==4.52.4
+triton==3.3.1
+typing_extensions==4.14.1
+tzdata==2025.2
 urllib3==2.5.0
 Werkzeug==3.1.3
+wrapt==1.17.2
+xxhash==3.5.0
+yarl==1.20.1
+zipp==3.23.0
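
The dependency diff mirrors the code change: the ROCm nightly torch/torchvision/torchaudio builds drop out in favor of stable torch 2.7.1, and the OpenVINO stack (openvino, openvino-tokenizers, optimum-intel, nncf) comes in. A quick sanity check, sketched against the pinned openvino==2025.2.0, to confirm the runtime imports and a CPU device is visible before starting the server:

import openvino as ov

# List the devices the OpenVINO runtime can target on this machine.
# pipe.compile() in main.py targets the default device (CPU), so "CPU"
# should appear here; the exact list depends on installed plugins.
core = ov.Core()
print(ov.get_version())
print(core.available_devices)  # e.g. ['CPU'] (assuming no GPU plugin is installed)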