import os
import base64
import configparser
import io
import json
import logging

import fal_client
import folder_paths
import numpy as np
import requests
import torch
from PIL import Image

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)


class BaseFalAPIFluxNode:
    def __init__(self):
        self.api_key = self.get_api_key()
        # Only export the key if one was actually found; a missing key raises a
        # clearer error later in prepare_arguments().
        if self.api_key:
            os.environ['FAL_KEY'] = self.api_key
        self.api_endpoint = None

    def get_api_key(self):
        config = configparser.ConfigParser()
        config_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'config.ini')
        if os.path.exists(config_path):
            config.read(config_path)
            return config.get('falai', 'api_key', fallback=None)
        return None

    def set_api_endpoint(self, endpoint):
        self.api_endpoint = endpoint

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "prompt": ("STRING", {"multiline": True}),
                "width": ("INT", {"default": 1024, "step": 8}),
                "height": ("INT", {"default": 1024, "step": 8}),
                "num_inference_steps": ("INT", {"default": 28, "min": 1, "max": 100}),
                "guidance_scale": ("FLOAT", {"default": 3.5, "min": 0.1, "max": 40.0}),
                "num_images": ("INT", {"default": 1, "min": 1, "max": 4}),
                "enable_safety_checker": ("BOOLEAN", {"default": True}),
            },
            "optional": {
                "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "generate"
    CATEGORY = "image generation"

    def prepare_arguments(self, prompt, width, height, num_inference_steps, guidance_scale,
                          num_images, enable_safety_checker, seed=None, **kwargs):
        if not self.api_key:
            raise ValueError("API key is not set. Please check your config.ini file.")

        arguments = {
            "prompt": prompt,
            "num_inference_steps": num_inference_steps,
            "guidance_scale": guidance_scale,
            "num_images": num_images,
            "enable_safety_checker": enable_safety_checker
        }

        # Handle custom image size
        if width is None or height is None:
            raise ValueError("Width and height must be provided when using a custom image size")
        arguments["image_size"] = {
            "width": width,
            "height": height
        }

        if seed is not None and seed != 0:
            arguments["seed"] = seed

        return arguments

    def call_api(self, arguments):
        logger.debug(f"Full API request payload: {json.dumps(arguments, indent=2)}")
        if not self.api_endpoint:
            raise ValueError("API endpoint is not set. Please set it using the set_api_endpoint() method.")
        try:
            handler = fal_client.submit(
                self.api_endpoint,
                arguments=arguments,
            )
            result = handler.get()
            logger.debug(f"API response: {json.dumps(result, indent=2)}")
            return result
        except Exception as e:
            logger.error(f"API error details: {str(e)}")
            if hasattr(e, 'response'):
                logger.error(f"API error response: {e.response.text}")
            raise RuntimeError(f"An error occurred when calling the fal.ai API: {str(e)}") from e

    def process_images(self, result):
        if "images" not in result or not result["images"]:
            logger.error("No images were generated by the API.")
            raise RuntimeError("No images were generated by the API.")

        output_images = []
        for index, img_info in enumerate(result["images"]):
            try:
                logger.debug(f"Processing image {index}: {json.dumps(img_info, indent=2)}")
                if not isinstance(img_info, dict) or "url" not in img_info or not img_info["url"]:
                    logger.error(f"Invalid image info for image {index}")
                    continue

                img_url = img_info["url"]
                logger.debug(f"Image URL: {img_url[:100]}...")  # Log the first 100 characters of the URL

                if img_url.startswith("data:image"):
                    # Handle a Base64-encoded data URL
                    try:
                        _, img_data = img_url.split(",", 1)
                        img_data = base64.b64decode(img_data)
                    except ValueError:
                        logger.error(f"Failed to split image URL for image {index}")
                        continue
                else:
                    # Handle a regular URL
                    try:
                        response = requests.get(img_url)
                        response.raise_for_status()
                        img_data = response.content
                    except requests.RequestException as e:
                        logger.error(f"Failed to download image from URL for image {index}: {str(e)}")
                        continue

                # Log the first few bytes of the image data
                logger.debug(f"First 20 bytes of image data: {img_data[:20]}")

                # Try to interpret the data as an encoded image
                try:
                    img = Image.open(io.BytesIO(img_data))
                    logger.debug(f"Opened image with size: {img.size} and mode: {img.mode}")
                except Exception as e:
                    logger.error(f"Failed to open image data: {str(e)}")
                    # If opening as an image fails, try to interpret it as raw pixel data
                    img_np = np.frombuffer(img_data, dtype=np.uint8)
                    logger.debug(f"Interpreted as raw pixel data with shape: {img_np.shape}")
                    # If the shape is (1024,), reshape it to a more sensible image size
                    if img_np.shape == (1024,):
                        img_np = img_np.reshape(32, 32)  # Reshape to a 32x32 image
                    elif img_np.shape == (1, 1, 1024):
                        img_np = img_np.reshape(32, 32)
                    # Normalize the data to the 0-255 range
                    img_np = ((img_np - img_np.min()) / (img_np.max() - img_np.min()) * 255).astype(np.uint8)
                    img = Image.fromarray(img_np, 'L')  # Create a grayscale image
                    img = img.convert('RGB')  # Convert to RGB

                # Ensure the image is in RGB mode
                if img.mode != 'RGB':
                    img = img.convert('RGB')

                # Convert the PIL Image to a NumPy array
                img_np = np.array(img).astype(np.float32) / 255.0

                # Create a tensor with a batch dimension (1, H, W, C)
                img_tensor = torch.from_numpy(img_np)
                img_tensor = img_tensor.unsqueeze(0)  # (1, H, W, C)
                output_images.append(img_tensor)
            except Exception as e:
                logger.error(f"Failed to process image {index}: {str(e)}")

        if not output_images:
            logger.error("Failed to process any of the generated images.")
            raise RuntimeError("Failed to process any of the generated images.")

        # Stack all images into a single batch tensor
        output_tensor = torch.cat(output_images, dim=0)
        logger.debug(f"Returning batched tensor with shape: {output_tensor.shape}")
        return [output_tensor]

    def upload_image(self, image):
        try:
            # Convert a PyTorch tensor to a NumPy array
            if isinstance(image, torch.Tensor):
                image = image.cpu().numpy()

            # Handle the different NumPy array shapes
            if isinstance(image, np.ndarray):
                if image.ndim == 4 and image.shape[0] == 1:  # (1, H, W, 3) or (1, H, W, 1)
                    image = image.squeeze(0)

                if image.ndim == 3:
                    if image.shape[2] == 3:  # (H, W, 3) RGB image
                        pass
                    elif image.shape[2] == 1:  # (H, W, 1) grayscale
                        image = np.repeat(image, 3, axis=2)
                    elif image.shape[0] == 3:  # (3, H, W) RGB
                        image = np.transpose(image, (1, 2, 0))
                    elif image.shape[0] == 1:  # (1, H, W) grayscale
                        image = np.repeat(image.squeeze(0)[..., np.newaxis], 3, axis=2)
                    elif image.shape == (1, 1, 1536):  # Special case for (1, 1, 1536) shape
                        image = image.reshape(32, 48)
                        image = np.repeat(image[..., np.newaxis], 3, axis=2)
                    else:
                        raise ValueError(f"Unsupported image shape: {image.shape}")

                # Normalize to the 0-255 range if not already uint8
                if image.dtype != np.uint8:
                    image = (image - image.min()) / (image.max() - image.min()) * 255
                    image = image.astype(np.uint8)

                image = Image.fromarray(image)

            # Ensure the image is in RGB mode
            if image.mode != 'RGB':
                image = image.convert('RGB')

            # Resize the image if it's too large (optional, adjust max_size as needed)
            max_size = 1024  # Example max size
            if max(image.size) > max_size:
                image.thumbnail((max_size, max_size), Image.LANCZOS)

            # Convert the PIL Image to bytes
            buffered = io.BytesIO()
            image.save(buffered, format="PNG")
            img_byte = buffered.getvalue()

            # Upload the image bytes using fal_client
            url = fal_client.upload(img_byte, "image/png")
            logger.info(f"Image uploaded successfully. URL: {url}")
            return url
        except Exception as e:
            logger.error(f"Failed to process or upload image: {str(e)}")
            raise

    def generate(self, **kwargs):
        arguments = self.prepare_arguments(**kwargs)
        result = self.call_api(arguments)
        output_images = self.process_images(result)
        return tuple(output_images)
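

# --- Illustrative usage sketch (not part of the original module) ---
# A concrete node only needs to pick an endpoint; INPUT_TYPES, generate(), and
# image handling are inherited from BaseFalAPIFluxNode. The endpoint id, class
# name, and display name below are assumptions for illustration; substitute the
# fal.ai endpoint and naming your project actually uses.
class FalAPIFluxDevNode(BaseFalAPIFluxNode):
    def __init__(self):
        super().__init__()
        self.set_api_endpoint("fal-ai/flux/dev")  # assumed endpoint id


# ComfyUI discovers custom nodes through these module-level mappings; the keys
# and display names here are placeholders.
NODE_CLASS_MAPPINGS = {
    "FalAPIFluxDevNode": FalAPIFluxDevNode,
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "FalAPIFluxDevNode": "fal.ai Flux (dev)",
}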