From b4e1a5ffae9f4d037156bffa6623d2c35624ad84 Mon Sep 17 00:00:00 2001 From: NikkeDoy Date: Mon, 7 Jul 2025 21:51:59 +0300 Subject: [PATCH] :sparkles: | Use SDXL from rupeshs instead --- main.py | 58 +++++++++++++++--------------- requirements.txt | 92 +++++++++++++++++++++++++++++++++++++++++------- 2 files changed, 110 insertions(+), 40 deletions(-) diff --git a/main.py b/main.py index 4d40197..5c02fbb 100644 --- a/main.py +++ b/main.py @@ -1,33 +1,30 @@ from flask import Flask, request, jsonify -from diffusers import DDPMScheduler, DiffusionPipeline -import torch +from optimum.intel.openvino.modeling_diffusion import OVStableDiffusionXLPipeline import base64 from io import BytesIO from PIL import Image +import os app = Flask(__name__) -scheduler = DDPMScheduler.from_pretrained("stabilityai/stable-diffusion-2-1-base", subfolder="scheduler") -pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base", scheduler=scheduler) +# Define paths for OpenVINO IR models +OV_MODEL_DIR = "echarlaix/stable-diffusion-2-1-openvino" # New directory for the LCM model + +# --- Model Loading --- +pipe = None -print("Loading UNet weights...") -ckpt_path = './models/model.bin' try: - unet_state_dict = torch.load(ckpt_path) - pipe.unet.load_state_dict(unet_state_dict) - if torch.cuda.is_available(): - pipe = pipe.to("cuda") - pipe.to(torch.float16) - print("UNet weights loaded successfully and model moved to CUDA.") - else: - print("CUDA is not available. Running on CPU (will be slow).") - pipe = pipe.to("cpu") - pipe.to(torch.float32) -except FileNotFoundError: - print(f"Error: Model not found at {ckpt_path}. Please ensure the file exists.") - exit() + pipe = OVStableDiffusionXLPipeline.from_pretrained( + "rupeshs/hyper-sd-sdxl-1-step-openvino-int8", + ov_config={"CACHE_DIR": ""}, + ) + + print("Compiling OpenVINO pipeline...") + pipe.compile() # Compile the pipeline for the target device (CPU by default) + print("OpenVINO compiled successfully.") + except Exception as e: - print(f"An error occurred while loading UNet weights: {e}") + print(f"An error occurred during OpenVINO LCM model loading or conversion: {e}") exit() @app.route('/api/generate', methods=['POST']) @@ -36,23 +33,28 @@ def generate_image(): if not prompt: return jsonify({"error": "Prompt is required in the request body."}), 400 - print(f"Generating image for prompt: '{prompt}'...") + print(f"Generating image for prompt: '{prompt}' using OpenVINO LCM...") try: - image = pipe(prompt=prompt, num_inference_steps=1, guidance_scale=0, timesteps=[999]).images[0] + # Crucially, set num_inference_steps=1 for 1-step generation with LCMs + # You might even omit guidance_scale for some LCMs, or use a very low value. + image = pipe( + prompt=prompt, + width=768, + height=768, + num_inference_steps=1, + guidance_scale=1.0, + ).images[0] image = image.resize((128, 128), Image.LANCZOS) - # Convert image to base64 buffered = BytesIO() - image.save(buffered, format="JPEG") # You can choose JPEG for smaller size, but PNG is lossless + image.save(buffered, format="PNG") img_str = base64.b64encode(buffered.getvalue()).decode("utf-8") return jsonify({"image": img_str}) except Exception as e: - print(f"Error during image generation: {e}") + print(f"Error during image generation with OpenVINO LCM: {e}") return jsonify({"error": "An error occurred during image generation."}), 500 if __name__ == '__main__': - # Make sure to run this with 'python app.py' - # It will be accessible at http://127.0.0.1:5000/ - app.run(host='127.0.0.1', port=5000, debug=True) # debug=True is good for development + app.run(host='127.0.0.1', port=5000, debug=True) diff --git a/requirements.txt b/requirements.txt index fe6f082..29ae68b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,39 +1,107 @@ +about-time==4.2.1 accelerate==1.8.1 +aiohappyeyeballs==2.6.1 +aiohttp==3.12.13 +aiosignal==1.4.0 +alive-progress==3.2.0 +attrs==25.3.0 +autograd==1.8.0 blinker==1.9.0 certifi==2025.6.15 charset-normalizer==3.4.2 click==8.2.1 +cma==4.2.0 +contourpy==1.3.2 +cycler==0.12.1 +datasets==3.6.0 +Deprecated==1.2.18 diffusers==0.34.0 +dill==0.3.8 filelock==3.18.0 Flask==3.1.1 -fsspec==2025.5.1 +fonttools==4.58.5 +frozenlist==1.7.0 +fsspec==2025.3.0 +grapheme==0.6.0 hf-xet==1.1.5 huggingface-hub==0.33.2 idna==3.10 importlib_metadata==8.7.0 itsdangerous==2.2.0 Jinja2==3.1.6 -MarkupSafe==2.1.5 +joblib==1.5.1 +jsonschema==4.24.0 +jsonschema-specifications==2025.4.1 +kiwisolver==1.4.8 +markdown-it-py==3.0.0 +MarkupSafe==3.0.2 +matplotlib==3.10.3 +mdurl==0.1.2 mpmath==1.3.0 -networkx==3.5 -numpy==2.3.1 +multidict==6.6.3 +multiprocess==0.70.16 +natsort==8.4.0 +networkx==3.4.2 +ninja==1.11.1.4 +nncf==2.17.0 +numpy==2.2.6 +nvidia-cublas-cu12==12.6.4.1 +nvidia-cuda-cupti-cu12==12.6.80 +nvidia-cuda-nvrtc-cu12==12.6.77 +nvidia-cuda-runtime-cu12==12.6.77 +nvidia-cudnn-cu12==9.5.1.17 +nvidia-cufft-cu12==11.3.0.4 +nvidia-cufile-cu12==1.11.1.6 +nvidia-curand-cu12==10.3.7.77 +nvidia-cusolver-cu12==11.7.1.2 +nvidia-cusparse-cu12==12.5.4.2 +nvidia-cusparselt-cu12==0.6.3 +nvidia-nccl-cu12==2.26.2 +nvidia-nvjitlink-cu12==12.6.85 +nvidia-nvtx-cu12==12.6.77 +onnx==1.18.0 +openvino==2025.2.0 +openvino-telemetry==2025.2.0 +openvino-tokenizers==2025.2.0.1 +optimum==1.26.1 +optimum-intel==1.24.0 packaging==25.0 -pillow==11.2.1 +pandas==2.2.3 +pillow==11.3.0 +propcache==0.3.2 +protobuf==6.31.1 psutil==7.0.0 -pytorch-triton-rocm==3.3.1+gitc8757738 +pyarrow==20.0.0 +pydot==3.0.4 +Pygments==2.19.2 +pymoo==0.6.1.5 +pyparsing==3.2.3 +python-dateutil==2.9.0.post0 +pytz==2025.2 PyYAML==6.0.2 +referencing==0.36.2 regex==2024.11.6 requests==2.32.4 +rich==14.0.0 +rpds-py==0.26.0 safetensors==0.5.3 -setuptools==78.1.0 +scikit-learn==1.7.0 +scipy==1.16.0 +setuptools==80.9.0 +six==1.17.0 sympy==1.14.0 +tabulate==0.9.0 +threadpoolctl==3.6.0 tokenizers==0.21.2 -torch==2.9.0.dev20250706+rocm6.4 -torchaudio==2.8.0.dev20250706+rocm6.4 -torchvision==0.24.0.dev20250706+rocm6.4 +torch==2.7.1 tqdm==4.67.1 -transformers==4.53.1 -typing_extensions==4.14.0 +transformers==4.52.4 +triton==3.3.1 +typing_extensions==4.14.1 +tzdata==2025.2 urllib3==2.5.0 Werkzeug==3.1.3 +wrapt==1.17.2 +xxhash==3.5.0 +yarl==1.20.1 zipp==3.23.0