From de05b0824f86a055c567f191955ef290cd90cd19 Mon Sep 17 00:00:00 2001 From: NikkeDoy Date: Mon, 7 Jul 2025 21:20:45 +0300 Subject: [PATCH] :tada: | Project added to Git --- main.py | 58 ++++++++++++++++++++++++++++++++++++++++++++++++ requirements.txt | 39 ++++++++++++++++++++++++++++++++ 2 files changed, 97 insertions(+) create mode 100644 main.py create mode 100644 requirements.txt diff --git a/main.py b/main.py new file mode 100644 index 0000000..4d40197 --- /dev/null +++ b/main.py @@ -0,0 +1,58 @@ +from flask import Flask, request, jsonify +from diffusers import DDPMScheduler, DiffusionPipeline +import torch +import base64 +from io import BytesIO +from PIL import Image + +app = Flask(__name__) + +scheduler = DDPMScheduler.from_pretrained("stabilityai/stable-diffusion-2-1-base", subfolder="scheduler") +pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base", scheduler=scheduler) + +print("Loading UNet weights...") +ckpt_path = './models/model.bin' +try: + unet_state_dict = torch.load(ckpt_path) + pipe.unet.load_state_dict(unet_state_dict) + if torch.cuda.is_available(): + pipe = pipe.to("cuda") + pipe.to(torch.float16) + print("UNet weights loaded successfully and model moved to CUDA.") + else: + print("CUDA is not available. Running on CPU (will be slow).") + pipe = pipe.to("cpu") + pipe.to(torch.float32) +except FileNotFoundError: + print(f"Error: Model not found at {ckpt_path}. Please ensure the file exists.") + exit() +except Exception as e: + print(f"An error occurred while loading UNet weights: {e}") + exit() + +@app.route('/api/generate', methods=['POST']) +def generate_image(): + prompt = request.json.get('prompt') + if not prompt: + return jsonify({"error": "Prompt is required in the request body."}), 400 + + print(f"Generating image for prompt: '{prompt}'...") + try: + image = pipe(prompt=prompt, num_inference_steps=1, guidance_scale=0, timesteps=[999]).images[0] + image = image.resize((128, 128), Image.LANCZOS) + + # Convert image to base64 + buffered = BytesIO() + image.save(buffered, format="JPEG") # You can choose JPEG for smaller size, but PNG is lossless + img_str = base64.b64encode(buffered.getvalue()).decode("utf-8") + + return jsonify({"image": img_str}) + + except Exception as e: + print(f"Error during image generation: {e}") + return jsonify({"error": "An error occurred during image generation."}), 500 + +if __name__ == '__main__': + # Make sure to run this with 'python app.py' + # It will be accessible at http://127.0.0.1:5000/ + app.run(host='127.0.0.1', port=5000, debug=True) # debug=True is good for development diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..fe6f082 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,39 @@ +accelerate==1.8.1 +blinker==1.9.0 +certifi==2025.6.15 +charset-normalizer==3.4.2 +click==8.2.1 +diffusers==0.34.0 +filelock==3.18.0 +Flask==3.1.1 +fsspec==2025.5.1 +hf-xet==1.1.5 +huggingface-hub==0.33.2 +idna==3.10 +importlib_metadata==8.7.0 +itsdangerous==2.2.0 +Jinja2==3.1.6 +MarkupSafe==2.1.5 +mpmath==1.3.0 +networkx==3.5 +numpy==2.3.1 +packaging==25.0 +pillow==11.2.1 +psutil==7.0.0 +pytorch-triton-rocm==3.3.1+gitc8757738 +PyYAML==6.0.2 +regex==2024.11.6 +requests==2.32.4 +safetensors==0.5.3 +setuptools==78.1.0 +sympy==1.14.0 +tokenizers==0.21.2 +torch==2.9.0.dev20250706+rocm6.4 +torchaudio==2.8.0.dev20250706+rocm6.4 +torchvision==0.24.0.dev20250706+rocm6.4 +tqdm==4.67.1 +transformers==4.53.1 +typing_extensions==4.14.0 +urllib3==2.5.0 +Werkzeug==3.1.3 +zipp==3.23.0