Compare commits
1 commit
main ... openvino-c

Author | SHA1 | Date
---|---|---
 | b4e1a5ffae | 

main.py (58 changed lines)
@@ -1,33 +1,30 @@
 from flask import Flask, request, jsonify
-from diffusers import DDPMScheduler, DiffusionPipeline
-import torch
+from optimum.intel.openvino.modeling_diffusion import OVStableDiffusionXLPipeline
 import base64
 from io import BytesIO
 from PIL import Image
+import os
 
 app = Flask(__name__)
 
-scheduler = DDPMScheduler.from_pretrained("stabilityai/stable-diffusion-2-1-base", subfolder="scheduler")
-pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base", scheduler=scheduler)
+# Define paths for OpenVINO IR models
+OV_MODEL_DIR = "echarlaix/stable-diffusion-2-1-openvino"  # New directory for the LCM model
 
+# --- Model Loading ---
+pipe = None
 
-print("Loading UNet weights...")
-ckpt_path = './models/model.bin'
 try:
-    unet_state_dict = torch.load(ckpt_path)
-    pipe.unet.load_state_dict(unet_state_dict)
-    if torch.cuda.is_available():
-        pipe = pipe.to("cuda")
-        pipe.to(torch.float16)
-        print("UNet weights loaded successfully and model moved to CUDA.")
-    else:
-        print("CUDA is not available. Running on CPU (will be slow).")
-        pipe = pipe.to("cpu")
-        pipe.to(torch.float32)
-except FileNotFoundError:
-    print(f"Error: Model not found at {ckpt_path}. Please ensure the file exists.")
-    exit()
+    pipe = OVStableDiffusionXLPipeline.from_pretrained(
+        "rupeshs/hyper-sd-sdxl-1-step-openvino-int8",
+        ov_config={"CACHE_DIR": ""},
+    )
+    print("Compiling OpenVINO pipeline...")
+    pipe.compile()  # Compile the pipeline for the target device (CPU by default)
+    print("OpenVINO compiled successfully.")
 except Exception as e:
-    print(f"An error occurred while loading UNet weights: {e}")
+    print(f"An error occurred during OpenVINO LCM model loading or conversion: {e}")
     exit()
 
 @app.route('/api/generate', methods=['POST'])
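Reviewer note: the new `OV_MODEL_DIR` constant is defined but never used; the pipeline is pulled from the "rupeshs/hyper-sd-sdxl-1-step-openvino-int8" Hub repo on every cold start and recompiled. Below is a minimal sketch (not part of this change) of how the IR could be persisted locally after the first download and reloaded on later starts; `LOCAL_MODEL_DIR` is a hypothetical path, and it assumes the `from_pretrained`/`save_pretrained` round-trip offered by optimum-intel pipelines.

```python
import os
from optimum.intel.openvino.modeling_diffusion import OVStableDiffusionXLPipeline

MODEL_ID = "rupeshs/hyper-sd-sdxl-1-step-openvino-int8"
LOCAL_MODEL_DIR = "./models/hyper-sd-sdxl-ov"  # hypothetical local cache location

if os.path.isdir(LOCAL_MODEL_DIR):
    # Later starts: load the already-downloaded OpenVINO IR from disk.
    pipe = OVStableDiffusionXLPipeline.from_pretrained(LOCAL_MODEL_DIR)
else:
    # First start: download from the Hub, then persist the IR locally.
    pipe = OVStableDiffusionXLPipeline.from_pretrained(MODEL_ID)
    pipe.save_pretrained(LOCAL_MODEL_DIR)

pipe.compile()  # compile for the default device (CPU) before serving requests
```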
@@ -36,23 +33,28 @@ def generate_image():
     if not prompt:
         return jsonify({"error": "Prompt is required in the request body."}), 400
 
-    print(f"Generating image for prompt: '{prompt}'...")
+    print(f"Generating image for prompt: '{prompt}' using OpenVINO LCM...")
     try:
-        image = pipe(prompt=prompt, num_inference_steps=1, guidance_scale=0, timesteps=[999]).images[0]
+        # Crucially, set num_inference_steps=1 for 1-step generation with LCMs
+        # You might even omit guidance_scale for some LCMs, or use a very low value.
+        image = pipe(
+            prompt=prompt,
+            width=768,
+            height=768,
+            num_inference_steps=1,
+            guidance_scale=1.0,
+        ).images[0]
         image = image.resize((128, 128), Image.LANCZOS)
 
-        # Convert image to base64
         buffered = BytesIO()
-        image.save(buffered, format="JPEG")  # You can choose JPEG for smaller size, but PNG is lossless
+        image.save(buffered, format="PNG")
         img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
 
         return jsonify({"image": img_str})
 
     except Exception as e:
-        print(f"Error during image generation: {e}")
+        print(f"Error during image generation with OpenVINO LCM: {e}")
         return jsonify({"error": "An error occurred during image generation."}), 500
 
 if __name__ == '__main__':
-    # Make sure to run this with 'python app.py'
-    # It will be accessible at http://127.0.0.1:5000/
-    app.run(host='127.0.0.1', port=5000, debug=True)  # debug=True is good for development
+    app.run(host='127.0.0.1', port=5000, debug=True)
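For reference, a minimal client sketch for the /api/generate endpoint above. It assumes the handler reads a JSON body with a "prompt" field (the diff only shows the validation and the response shape) and that the server runs locally on port 5000 as configured in app.run(); the output filename is arbitrary.

```python
import base64
import requests

resp = requests.post(
    "http://127.0.0.1:5000/api/generate",
    json={"prompt": "a watercolor painting of a lighthouse at dawn"},
    timeout=300,  # 1-step generation is fast, but the first request may include warm-up
)
resp.raise_for_status()

# The endpoint returns {"image": "<base64-encoded PNG>"}; decode and save it.
with open("generated.png", "wb") as f:
    f.write(base64.b64decode(resp.json()["image"]))
```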
requirements.txt
@@ -1,39 +1,107 @@
+about-time==4.2.1
 accelerate==1.8.1
+aiohappyeyeballs==2.6.1
+aiohttp==3.12.13
+aiosignal==1.4.0
+alive-progress==3.2.0
+attrs==25.3.0
+autograd==1.8.0
 blinker==1.9.0
 certifi==2025.6.15
 charset-normalizer==3.4.2
 click==8.2.1
+cma==4.2.0
+contourpy==1.3.2
+cycler==0.12.1
+datasets==3.6.0
+Deprecated==1.2.18
 diffusers==0.34.0
+dill==0.3.8
 filelock==3.18.0
 Flask==3.1.1
-fsspec==2025.5.1
+fonttools==4.58.5
+frozenlist==1.7.0
+fsspec==2025.3.0
+grapheme==0.6.0
 hf-xet==1.1.5
 huggingface-hub==0.33.2
 idna==3.10
 importlib_metadata==8.7.0
 itsdangerous==2.2.0
 Jinja2==3.1.6
-MarkupSafe==2.1.5
+joblib==1.5.1
+jsonschema==4.24.0
+jsonschema-specifications==2025.4.1
+kiwisolver==1.4.8
+markdown-it-py==3.0.0
+MarkupSafe==3.0.2
+matplotlib==3.10.3
+mdurl==0.1.2
 mpmath==1.3.0
-networkx==3.5
-numpy==2.3.1
+multidict==6.6.3
+multiprocess==0.70.16
+natsort==8.4.0
+networkx==3.4.2
+ninja==1.11.1.4
+nncf==2.17.0
+numpy==2.2.6
+nvidia-cublas-cu12==12.6.4.1
+nvidia-cuda-cupti-cu12==12.6.80
+nvidia-cuda-nvrtc-cu12==12.6.77
+nvidia-cuda-runtime-cu12==12.6.77
+nvidia-cudnn-cu12==9.5.1.17
+nvidia-cufft-cu12==11.3.0.4
+nvidia-cufile-cu12==1.11.1.6
+nvidia-curand-cu12==10.3.7.77
+nvidia-cusolver-cu12==11.7.1.2
+nvidia-cusparse-cu12==12.5.4.2
+nvidia-cusparselt-cu12==0.6.3
+nvidia-nccl-cu12==2.26.2
+nvidia-nvjitlink-cu12==12.6.85
+nvidia-nvtx-cu12==12.6.77
+onnx==1.18.0
+openvino==2025.2.0
+openvino-telemetry==2025.2.0
+openvino-tokenizers==2025.2.0.1
+optimum==1.26.1
+optimum-intel==1.24.0
 packaging==25.0
-pillow==11.2.1
+pandas==2.2.3
+pillow==11.3.0
+propcache==0.3.2
+protobuf==6.31.1
 psutil==7.0.0
-pytorch-triton-rocm==3.3.1+gitc8757738
+pyarrow==20.0.0
+pydot==3.0.4
+Pygments==2.19.2
+pymoo==0.6.1.5
+pyparsing==3.2.3
+python-dateutil==2.9.0.post0
+pytz==2025.2
 PyYAML==6.0.2
+referencing==0.36.2
 regex==2024.11.6
 requests==2.32.4
+rich==14.0.0
+rpds-py==0.26.0
 safetensors==0.5.3
-setuptools==78.1.0
+scikit-learn==1.7.0
+scipy==1.16.0
+setuptools==80.9.0
+six==1.17.0
 sympy==1.14.0
+tabulate==0.9.0
+threadpoolctl==3.6.0
 tokenizers==0.21.2
-torch==2.9.0.dev20250706+rocm6.4
-torchaudio==2.8.0.dev20250706+rocm6.4
-torchvision==0.24.0.dev20250706+rocm6.4
+torch==2.7.1
 tqdm==4.67.1
-transformers==4.53.1
-typing_extensions==4.14.0
+transformers==4.52.4
+triton==3.3.1
+typing_extensions==4.14.1
+tzdata==2025.2
 urllib3==2.5.0
 Werkzeug==3.1.3
+wrapt==1.17.2
+xxhash==3.5.0
+yarl==1.20.1
 zipp==3.23.0
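Since this branch drops the ROCm nightly torch wheels and moves inference onto the OpenVINO runtime (with pipe.compile() targeting the default device, CPU), a small sanity check of the new environment can be useful before starting the server. The sketch below is illustrative only and assumes the standard openvino Python API (Core, get_version); it is not part of this change.

```python
import openvino as ov

print("OpenVINO runtime:", ov.get_version())
print("Available devices:", ov.Core().available_devices)  # e.g. ['CPU'] or ['CPU', 'GPU']

# Confirm the optimum-intel pipeline class used in main.py can be imported.
from optimum.intel.openvino.modeling_diffusion import OVStableDiffusionXLPipeline
print("optimum-intel import OK:", OVStableDiffusionXLPipeline.__name__)
```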