Compare commits

..

1 Commits

Author SHA1 Message Date
b4e1a5ffae | Use SDXL from rupeshs instead 2025-07-07 21:51:59 +03:00
2 changed files with 110 additions and 40 deletions

58
main.py
View File

@@ -1,33 +1,30 @@
from flask import Flask, request, jsonify from flask import Flask, request, jsonify
from diffusers import DDPMScheduler, DiffusionPipeline from optimum.intel.openvino.modeling_diffusion import OVStableDiffusionXLPipeline
import torch
import base64 import base64
from io import BytesIO from io import BytesIO
from PIL import Image from PIL import Image
import os
app = Flask(__name__) app = Flask(__name__)
scheduler = DDPMScheduler.from_pretrained("stabilityai/stable-diffusion-2-1-base", subfolder="scheduler") # Define paths for OpenVINO IR models
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base", scheduler=scheduler) OV_MODEL_DIR = "echarlaix/stable-diffusion-2-1-openvino" # New directory for the LCM model
# --- Model Loading ---
pipe = None
print("Loading UNet weights...")
ckpt_path = './models/model.bin'
try: try:
unet_state_dict = torch.load(ckpt_path) pipe = OVStableDiffusionXLPipeline.from_pretrained(
pipe.unet.load_state_dict(unet_state_dict) "rupeshs/hyper-sd-sdxl-1-step-openvino-int8",
if torch.cuda.is_available(): ov_config={"CACHE_DIR": ""},
pipe = pipe.to("cuda") )
pipe.to(torch.float16)
print("UNet weights loaded successfully and model moved to CUDA.") print("Compiling OpenVINO pipeline...")
else: pipe.compile() # Compile the pipeline for the target device (CPU by default)
print("CUDA is not available. Running on CPU (will be slow).") print("OpenVINO compiled successfully.")
pipe = pipe.to("cpu")
pipe.to(torch.float32)
except FileNotFoundError:
print(f"Error: Model not found at {ckpt_path}. Please ensure the file exists.")
exit()
except Exception as e: except Exception as e:
print(f"An error occurred while loading UNet weights: {e}") print(f"An error occurred during OpenVINO LCM model loading or conversion: {e}")
exit() exit()
@app.route('/api/generate', methods=['POST']) @app.route('/api/generate', methods=['POST'])
@@ -36,23 +33,28 @@ def generate_image():
if not prompt: if not prompt:
return jsonify({"error": "Prompt is required in the request body."}), 400 return jsonify({"error": "Prompt is required in the request body."}), 400
print(f"Generating image for prompt: '{prompt}'...") print(f"Generating image for prompt: '{prompt}' using OpenVINO LCM...")
try: try:
image = pipe(prompt=prompt, num_inference_steps=1, guidance_scale=0, timesteps=[999]).images[0] # Crucially, set num_inference_steps=1 for 1-step generation with LCMs
# You might even omit guidance_scale for some LCMs, or use a very low value.
image = pipe(
prompt=prompt,
width=768,
height=768,
num_inference_steps=1,
guidance_scale=1.0,
).images[0]
image = image.resize((128, 128), Image.LANCZOS) image = image.resize((128, 128), Image.LANCZOS)
# Convert image to base64
buffered = BytesIO() buffered = BytesIO()
image.save(buffered, format="JPEG") # You can choose JPEG for smaller size, but PNG is lossless image.save(buffered, format="PNG")
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8") img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
return jsonify({"image": img_str}) return jsonify({"image": img_str})
except Exception as e: except Exception as e:
print(f"Error during image generation: {e}") print(f"Error during image generation with OpenVINO LCM: {e}")
return jsonify({"error": "An error occurred during image generation."}), 500 return jsonify({"error": "An error occurred during image generation."}), 500
if __name__ == '__main__': if __name__ == '__main__':
# Make sure to run this with 'python app.py' app.run(host='127.0.0.1', port=5000, debug=True)
# It will be accessible at http://127.0.0.1:5000/
app.run(host='127.0.0.1', port=5000, debug=True) # debug=True is good for development

View File

@@ -1,39 +1,107 @@
about-time==4.2.1
accelerate==1.8.1 accelerate==1.8.1
aiohappyeyeballs==2.6.1
aiohttp==3.12.13
aiosignal==1.4.0
alive-progress==3.2.0
attrs==25.3.0
autograd==1.8.0
blinker==1.9.0 blinker==1.9.0
certifi==2025.6.15 certifi==2025.6.15
charset-normalizer==3.4.2 charset-normalizer==3.4.2
click==8.2.1 click==8.2.1
cma==4.2.0
contourpy==1.3.2
cycler==0.12.1
datasets==3.6.0
Deprecated==1.2.18
diffusers==0.34.0 diffusers==0.34.0
dill==0.3.8
filelock==3.18.0 filelock==3.18.0
Flask==3.1.1 Flask==3.1.1
fsspec==2025.5.1 fonttools==4.58.5
frozenlist==1.7.0
fsspec==2025.3.0
grapheme==0.6.0
hf-xet==1.1.5 hf-xet==1.1.5
huggingface-hub==0.33.2 huggingface-hub==0.33.2
idna==3.10 idna==3.10
importlib_metadata==8.7.0 importlib_metadata==8.7.0
itsdangerous==2.2.0 itsdangerous==2.2.0
Jinja2==3.1.6 Jinja2==3.1.6
MarkupSafe==2.1.5 joblib==1.5.1
jsonschema==4.24.0
jsonschema-specifications==2025.4.1
kiwisolver==1.4.8
markdown-it-py==3.0.0
MarkupSafe==3.0.2
matplotlib==3.10.3
mdurl==0.1.2
mpmath==1.3.0 mpmath==1.3.0
networkx==3.5 multidict==6.6.3
numpy==2.3.1 multiprocess==0.70.16
natsort==8.4.0
networkx==3.4.2
ninja==1.11.1.4
nncf==2.17.0
numpy==2.2.6
nvidia-cublas-cu12==12.6.4.1
nvidia-cuda-cupti-cu12==12.6.80
nvidia-cuda-nvrtc-cu12==12.6.77
nvidia-cuda-runtime-cu12==12.6.77
nvidia-cudnn-cu12==9.5.1.17
nvidia-cufft-cu12==11.3.0.4
nvidia-cufile-cu12==1.11.1.6
nvidia-curand-cu12==10.3.7.77
nvidia-cusolver-cu12==11.7.1.2
nvidia-cusparse-cu12==12.5.4.2
nvidia-cusparselt-cu12==0.6.3
nvidia-nccl-cu12==2.26.2
nvidia-nvjitlink-cu12==12.6.85
nvidia-nvtx-cu12==12.6.77
onnx==1.18.0
openvino==2025.2.0
openvino-telemetry==2025.2.0
openvino-tokenizers==2025.2.0.1
optimum==1.26.1
optimum-intel==1.24.0
packaging==25.0 packaging==25.0
pillow==11.2.1 pandas==2.2.3
pillow==11.3.0
propcache==0.3.2
protobuf==6.31.1
psutil==7.0.0 psutil==7.0.0
pytorch-triton-rocm==3.3.1+gitc8757738 pyarrow==20.0.0
pydot==3.0.4
Pygments==2.19.2
pymoo==0.6.1.5
pyparsing==3.2.3
python-dateutil==2.9.0.post0
pytz==2025.2
PyYAML==6.0.2 PyYAML==6.0.2
referencing==0.36.2
regex==2024.11.6 regex==2024.11.6
requests==2.32.4 requests==2.32.4
rich==14.0.0
rpds-py==0.26.0
safetensors==0.5.3 safetensors==0.5.3
setuptools==78.1.0 scikit-learn==1.7.0
scipy==1.16.0
setuptools==80.9.0
six==1.17.0
sympy==1.14.0 sympy==1.14.0
tabulate==0.9.0
threadpoolctl==3.6.0
tokenizers==0.21.2 tokenizers==0.21.2
torch==2.9.0.dev20250706+rocm6.4 torch==2.7.1
torchaudio==2.8.0.dev20250706+rocm6.4
torchvision==0.24.0.dev20250706+rocm6.4
tqdm==4.67.1 tqdm==4.67.1
transformers==4.53.1 transformers==4.52.4
typing_extensions==4.14.0 triton==3.3.1
typing_extensions==4.14.1
tzdata==2025.2
urllib3==2.5.0 urllib3==2.5.0
Werkzeug==3.1.3 Werkzeug==3.1.3
wrapt==1.17.2
xxhash==3.5.0
yarl==1.20.1
zipp==3.23.0 zipp==3.23.0