diff --git a/src/arroyosas/app/db_replay_sim_cli.py b/src/arroyosas/app/db_replay_sim_cli.py deleted file mode 100644 index bcf54c5..0000000 --- a/src/arroyosas/app/db_replay_sim_cli.py +++ /dev/null @@ -1,305 +0,0 @@ -import asyncio -import logging -import os -import aiosqlite -from datetime import datetime - -import typer -import zmq -import zmq.asyncio -import msgpack -from tiled.client import from_uri - -from ..config import settings -from ..schemas import ( - RawFrameEvent, - SASStart, - SASStop, - SerializableNumpyArrayModel, -) - -""" -Simulates image retrieval by reading Tiled URLs from a local SQLite database -and sends the fetched images onto ZMQ. -""" - -# Get configuration from environment variables with defaults -DEFAULT_DB_PATH = os.getenv("DB_PATH", "latent_vectors.db") -DEFAULT_API_KEY = os.getenv("TILED_API_KEY", None) -# New environment variable for selecting prod vs dev environment -TILED_ENV = os.getenv("TILED_ENV", "dev").lower() - -# Define environment-specific URLs -TILED_URLS = { - "dev": { - "url_pattern": "http://tiled-dev.nsls2.bnl.gov/api/v1/array/full/" - }, - "prod": { - "url_pattern": "http://tiled.nsls2.bnl.gov/api/v1/array/full/" - } -} - -# Setup logging -logger = logging.getLogger(__name__) -logging.basicConfig(level=logging.INFO) - -app = typer.Typer() - - -async def get_urls_from_db(db_path, limit=None): - """Get a list of Tiled URLs from the database asynchronously""" - try: - async with aiosqlite.connect(db_path) as conn: - query = "SELECT id, tiled_url FROM vectors ORDER BY id" - if limit: - query += f" LIMIT {limit}" - - async with conn.execute(query) as cursor: - results = await cursor.fetchall() - - if not results: - logger.warning(f"No Tiled URLs found in database {db_path}") - return [] - - logger.info(f"Found {len(results)} Tiled URLs in database") - return results - except Exception as e: - logger.error(f"Error reading from database: {e}") - return [] - - -def transform_url_for_env(tiled_url, env): - """ - Transform a Tiled URL to match the specified environment format. - - Args: - tiled_url: The original Tiled URL (typically from dev environment) - env: The target environment ('dev' or 'prod') - - Returns: - str: Transformed URL for the target environment - """ - if env not in TILED_URLS: - logger.warning(f"Unknown environment '{env}', falling back to 'dev'") - env = "dev" - - # If we're staying in dev, no transformation needed - if env == "dev" and "tiled-dev.nsls2.bnl.gov" in tiled_url: - return tiled_url - - # Extract slice parameter if present - slice_param = None - if '?' in tiled_url: - slice_param = tiled_url.split('?')[1] - - # Parse the URL to extract UUID and stream path - url_without_query = tiled_url.split('?')[0] # Remove query parameters - - # Extract UUID and stream path - uuid = None - stream_path = None - - if 'array/full/' in url_without_query: - path_after_full = url_without_query.split('array/full/')[1] - parts = path_after_full.split('/') - if len(parts) >= 1: - uuid = parts[0] - if len(parts) > 1: - stream_path = '/'.join(parts[1:]) - - if not uuid or not stream_path: - logger.error(f"Could not parse Tiled URL: {tiled_url}") - return tiled_url # Return original if parsing fails - - # Transform URL based on environment - if env == "prod": - # Get image name from stream_path - parts = stream_path.split('/') - if len(parts) > 0: - image_name = parts[-1] - # Format: http://tiled.nsls2.bnl.gov/api/v1/array/full/smi/raw/{uuid}/primary/data/{image_name}?slice=... 
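The prod rewrite described in the comment above boils down to a handful of string operations. A minimal standalone sketch of the same transformation, using a hypothetical UUID and stream path (not a real dataset):

```python
# Sketch of the dev -> prod URL rewrite; "abc-123" and the stream path are placeholders.
DEV_PREFIX = "http://tiled-dev.nsls2.bnl.gov/api/v1/array/full/"
PROD_PREFIX = "http://tiled.nsls2.bnl.gov/api/v1/array/full/"

def to_prod(dev_url: str) -> str:
    path, _, query = dev_url.partition("?")
    uuid, _, stream_path = path.split("array/full/", 1)[1].partition("/")
    image_name = stream_path.rsplit("/", 1)[-1]
    prod_url = f"{PROD_PREFIX}smi/raw/{uuid}/primary/data/{image_name}"
    return f"{prod_url}?{query}" if query else prod_url

example = DEV_PREFIX + "abc-123/primary/data/pil1M_image?slice=0:1"
print(to_prod(example))
# -> http://tiled.nsls2.bnl.gov/api/v1/array/full/smi/raw/abc-123/primary/data/pil1M_image?slice=0:1
```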
- new_url = f"{TILED_URLS[env]['url_pattern']}smi/raw/{uuid}/primary/data/{image_name}" - else: - # Fallback if we can't extract image name - logger.error(f"Could not extract image name from stream path: {stream_path}") - return tiled_url - else: - # Dev URL format: http://tiled-dev.nsls2.bnl.gov/api/v1/array/full/{uuid}/{stream_path}?slice=... - new_url = f"{TILED_URLS[env]['url_pattern']}{uuid}/{stream_path}" - - # Add slice parameter if it exists - if slice_param: - new_url = f"{new_url}?{slice_param}" - - logger.debug(f"Transformed URL: {tiled_url} -> {new_url}") - return new_url - - -def _read_image_from_tiled_url_sync(tiled_url, api_key=None): - """ - Read an image from a Tiled URL. - - Args: - tiled_url: The Tiled URL (already transformed for the appropriate environment) - api_key: API key for Tiled authentication - - Returns: - tuple: (image_data, index) - """ - try: - # Extract index from slice parameter - index = 0 # Default index - if '?' in tiled_url and 'slice=' in tiled_url: - slice_param = tiled_url.split('slice=')[1].split('&')[0] - if ':' in slice_param: - parts = slice_param.split(',')[0].split(':') - if parts[0].isdigit(): - index = int(parts[0]) - - # Parse the URL to extract base URL and path - url_without_query = tiled_url.split('?')[0] # Remove query parameters - url_parts = url_without_query.split('/api/v1/') - - if len(url_parts) != 2: - logger.error(f"Invalid Tiled URL format: {tiled_url}") - return None, 0 - - # Change array/full to metadata - base_uri = f"{url_parts[0]}/api/v1/metadata" - - # Extract dataset URI - get everything after "array/full/" - full_path = url_parts[1] - - if 'array/full/' in url_without_query: - # If the URL contains array/full, extract the part after it - path_parts = full_path.split('array/full/') - if len(path_parts) > 1: - dataset_uri = path_parts[1] - else: - dataset_uri = full_path - else: - # If URL doesn't contain array/full, use the whole path - dataset_uri = full_path - - logger.debug(f"Base URI: {base_uri}, Dataset URI: {dataset_uri}, Index: {index}") - - # Connect to the Tiled server - client = from_uri(base_uri, api_key=api_key) - - # Access the dataset - tiled_data = client[dataset_uri] - logger.debug(f"Dataset shape: {tiled_data.shape}, dtype: {tiled_data.dtype}") - - # Retrieve the image at the specified index - image = tiled_data[index] - - return image, index - - except Exception as e: - logger.error(f"Error reading from Tiled URL {tiled_url}: {e}") - return None, 0 - - -async def read_image_from_tiled_url(tiled_url, api_key=None): - """Async wrapper for _read_image_from_tiled_url_sync""" - return await asyncio.to_thread(_read_image_from_tiled_url_sync, tiled_url, api_key) - - -@app.command() -def main( - db_path: str = typer.Option(DEFAULT_DB_PATH, help="Path to the SQLite database containing Tiled URLs"), - max_frames: int = typer.Option(10000, help="Maximum number of frames to process"), - api_key: str = typer.Option(DEFAULT_API_KEY, help="API key for Tiled authentication"), - env: str = typer.Option(TILED_ENV, help="Tiled environment to use ('dev' or 'prod')") -): - """ - Run the image simulator that reads Tiled URLs from a database, fetches the images, and publishes them via ZMQ. - - Configuration can be set via environment variables: - - DB_PATH: Path to the SQLite database - - TILED_API_KEY: API key for Tiled authentication - - TILED_ENV: Environment to use ('dev' or 'prod') - - Command-line arguments override environment variables. 
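`_read_image_from_tiled_url_sync` recovers the frame index by hand-splitting the `slice=` query parameter. A sketch of the same extraction with `urllib.parse`, assuming the `slice=N:M,...` shape the replay URLs carry:

```python
from urllib.parse import parse_qs, urlparse

def slice_index(tiled_url: str) -> int:
    qs = parse_qs(urlparse(tiled_url).query)
    slice_param = qs.get("slice", ["0"])[0]          # e.g. "12:13,0:1679,0:1475"
    first = slice_param.split(",")[0].split(":")[0]  # leading index of the first axis
    return int(first) if first.isdigit() else 0

assert slice_index("http://host/api/v1/array/full/u/p?slice=12:13,0:10") == 12
```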
- """ - # Log the configuration - logger.info(f"Starting DB Image Simulator with:") - logger.info(f"- Database path: {db_path}") - logger.info(f"- Max frames: {max_frames}") - logger.info(f"- API key provided: {api_key is not None}") - logger.info(f"- Tiled environment: {env}") - - async def run(): - # Check if database exists - if not os.path.exists(db_path): - logger.error(f"Database file not found: {db_path}") - return - - # Setup ZMQ socket - context = zmq.asyncio.Context() - socket = context.socket(zmq.PUB) - address = settings.tiled_poller.zmq_frame_publisher.address - logger.info(f"Binding to ZMQ address: {address}") - socket.bind(address) - - # Get URLs from database - urls = await get_urls_from_db(db_path, limit=max_frames) - if not urls: - logger.error("No URLs found in database, cannot continue") - return - - # Send start event - current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") - start = SASStart( - width=1679, # Default values - these will be updated from real data - height=1475, - data_type="uint32", - tiled_url=f"{env}://latent_vectors", - run_name=f"{env}_tiled_run", - run_id=str(current_time), - ) - logger.info(f"Sending start event") - await socket.send(msgpack.packb(start.model_dump())) - - # Process each URL - for db_id, tiled_url in urls: - try: - logger.info(f"Processing URL from DB record {db_id}: {tiled_url}") - - # Transform the URL for the current environment before processing - transformed_url = transform_url_for_env(tiled_url, env) - logger.info(f"Transformed URL: {transformed_url}") - - # Read image data from transformed Tiled URL - image_data, index = await read_image_from_tiled_url(transformed_url, api_key) - - if image_data is None: - logger.error(f"Failed to read image from {transformed_url}") - continue - - # Send the frame event with transformed URL - event = RawFrameEvent( - image=SerializableNumpyArrayModel(array=image_data), - frame_number=index, - tiled_url=transformed_url, - ) - logger.info(f"Sending frame {index}") - await socket.send(msgpack.packb(event.model_dump())) - - # Small delay between frames - await asyncio.sleep(0.1) - - except Exception as e: - logger.error(f"Error processing frame from {tiled_url}: {e}") - - # Send stop event - stop = SASStop(num_frames=len(urls)) - logger.info(f"Sending stop event") - await socket.send(msgpack.packb(stop.model_dump())) - logger.info(f"Complete - sent {len(urls)} frames") - - asyncio.run(run()) - - -if __name__ == "__main__": - app() \ No newline at end of file diff --git a/src/arroyosas/app/ingest_local_images.py b/src/arroyosas/app/ingest_local_images.py new file mode 100644 index 0000000..d442f9b --- /dev/null +++ b/src/arroyosas/app/ingest_local_images.py @@ -0,0 +1,213 @@ +import asyncio +import glob +import json +import logging +import os +import time +from datetime import datetime +from typing import List + +import numpy as np +import typer +from PIL import Image +from tiled.client import from_uri + +# Default settings +DEFAULT_IMAGE_FOLDER = os.getenv("DEFAULT_IMAGE_FOLDER", "./images") +TILED_URI = os.getenv("TILED_URI", "http://localhost:8000/api/v1/metadata") +TILED_API_KEY = os.getenv("TILED_API_KEY", None) +TILED_CONTAINER = os.getenv("TILED_CONTAINER", "733data") +URL_FILE = os.getenv("URL_FILE", "./tiled_url.json") + +# Setup logging +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + +app = typer.Typer() + + +def load_image_files(folder_path: str) -> List[str]: + """Load image files from a folder""" + extensions = ['.jpg', '.jpeg', '.png', '.tiff', 
'.tif'] + + all_files = [] + for ext in extensions: + pattern = os.path.join(folder_path, f"*{ext}") + all_files.extend(glob.glob(pattern)) + pattern = os.path.join(folder_path, f"*{ext.upper()}") + all_files.extend(glob.glob(pattern)) + + all_files.sort() + + if not all_files: + logger.warning(f"No image files found in {folder_path}") + else: + logger.info(f"Found {len(all_files)} image files in {folder_path}") + + return all_files + + +def read_image_file(file_path: str) -> np.ndarray: + """Read an image file and convert it to a numpy array""" + try: + with Image.open(file_path) as img: + if img.mode != "I" and img.mode != 'L': + img = img.convert('L') + array = np.array(img, dtype=np.uint32) + else: + array = np.array(img) + return array + except Exception as e: + logger.error(f"Error reading image file {file_path}: {e}") + return np.zeros((10, 10), dtype=np.uint32) + + +def save_url_to_file(url: str, file_path: str, metadata: dict = None): + """Save Tiled URL and metadata to a file""" + data = { + "tiled_url": url, + "timestamp": datetime.now().isoformat(), + "metadata": metadata or {} + } + + try: + with open(file_path, 'w') as f: + json.dump(data, f, indent=2) + logger.info(f"Saved Tiled URL to {file_path}") + except Exception as e: + logger.error(f"Error saving URL to file: {e}") + + +async def ingest_to_tiled(client, container_name: str, image_files: List[str]) -> tuple: + """Ingest images to Tiled one by one and return the URL and metadata""" + try: + # Create container + try: + container = client.create_container(key=container_name) + logger.info(f"Created container: {container_name}") + except Exception as e: + logger.info(f"Container may already exist, trying to access: {e}") + container = client[container_name] + + # Create a timestamp-based unique identifier for this run + timestamp = int(time.time()) + run_id = f"local_images_{timestamp}" + + # Process the first image to get dimensions and create metadata + first_image = read_image_file(image_files[0]) + height, width = first_image.shape + data_type = str(first_image.dtype) + + # Initialize metadata + metadata = { + "source": "local_image_sim", + "timestamp": datetime.now().isoformat(), + "num_images": len(image_files), + "width": width, + "height": height, + "data_type": data_type, + } + + # Process each image + for i, file_path in enumerate(image_files): + if i % 10 == 0: + logger.info(f"Processing image {i+1}/{len(image_files)}") + + # Read image + image = read_image_file(file_path) + + # Create image key + image_key = f"{run_id}_{i:04d}" + + # Create image metadata + image_metadata = { + "index": i, + "filename": os.path.basename(file_path), + "timestamp": datetime.now().isoformat(), + } + + # Write image to Tiled + container.write_array( + key=image_key, + array=image, + metadata=image_metadata + ) + + # Clear memory + del image + + # Write overall metadata + container.write_array( + key=f"{run_id}_metadata", + array=np.array([0]), # Minimal array + metadata=metadata + ) + + # Construct and return the Tiled URL for the ingested data + base_url = client.uri.replace("/metadata", "") + tiled_url = f"{base_url}/array/full/{container_name}/{run_id}" + logger.info(f"Data ingested to Tiled at URL: {tiled_url}") + + # Store information for simulator to find all images + metadata["image_pattern"] = f"{run_id}_[0-9]{{4}}" + + return tiled_url, metadata + + except Exception as e: + logger.error(f"Error ingesting data to Tiled: {e}") + import traceback + logger.error(traceback.format_exc()) + raise + + +@app.command() +def main( + 
image_folder: str = typer.Option(DEFAULT_IMAGE_FOLDER, help="Path to folder containing image files"), + tiled_uri: str = typer.Option(TILED_URI, help="URI of the Tiled server"), + api_key: str = typer.Option(TILED_API_KEY, help="API key for Tiled authentication"), + container: str = typer.Option(TILED_CONTAINER, help="Name of the Tiled container"), + url_file: str = typer.Option(URL_FILE, help="Path to file to save the Tiled URL"), +): + """ + Read images from a local folder, ingest them to a Tiled server one by one, + and save the resulting URL to a local file for later use. + """ + logger.info(f"Starting Local Image Ingestion with:") + logger.info(f"- Image folder: {image_folder}") + logger.info(f"- Tiled URI: {tiled_uri}") + logger.info(f"- URL file: {url_file}") + + async def run(): + # Check if image folder exists + if not os.path.exists(image_folder): + logger.error(f"Image folder not found: {image_folder}") + return + + # Load image files + image_files = load_image_files(image_folder) + if not image_files: + logger.error(f"No image files found in {image_folder}") + return + + # Connect to Tiled server and ingest images + try: + client = from_uri(tiled_uri, api_key=api_key) + logger.info(f"Connected to Tiled server at {tiled_uri}") + + tiled_url, metadata = await ingest_to_tiled(client, container, image_files) + + # Save the URL to file + save_url_to_file(tiled_url, url_file, metadata) + logger.info(f"Successfully ingested {len(image_files)} images to Tiled") + logger.info(f"Tiled URL: {tiled_url}") + + except Exception as e: + logger.error(f"Failed to ingest images to Tiled: {e}") + import traceback + logger.error(traceback.format_exc()) + + asyncio.run(run()) + + +if __name__ == "__main__": + app() \ No newline at end of file diff --git a/src/arroyosas/app/real_image_sim_cli.py b/src/arroyosas/app/real_image_sim_cli.py deleted file mode 100644 index f1af498..0000000 --- a/src/arroyosas/app/real_image_sim_cli.py +++ /dev/null @@ -1,162 +0,0 @@ -import asyncio -import os -import logging -from datetime import datetime -import typer -import zmq -import zmq.asyncio -import msgpack -from tiled.client import from_uri - -from ..config import settings -from ..schemas import ( - RawFrameEvent, - SASStart, - SASStop, - SerializableNumpyArrayModel, -) - -""" -Simulates image retrieval from Tiled and sends them onto ZMQ, -taking care of pydantic messages, serialization and msgpack -""" - -# Default Tiled configuration - use the exact URI format from data_simulator.py -DATA_TILED_URI = ( - "https://tiled-demo.blueskyproject.io/api/v1/metadata/rsoxs/raw/" - "468810ed-2ff9-4e92-8ca9-dcb376d01a56/primary/data/Small Angle CCD Detector_image" -) -TILED_API_KEY = os.getenv("DATA_TILED_KEY") -if TILED_API_KEY == "": - TILED_API_KEY = None - -# Frame configuration -FRAME_WIDTH = 1024 -FRAME_HEIGHT = 1026 -DATA_TYPE = "uint32" - -# Setup logging -logger = logging.getLogger(__name__) -logging.basicConfig(level=logging.INFO) - -app = typer.Typer() - - -def get_num_frames(tiled_uri, tiled_api_key=None): - """Get the number of available frames from Tiled, following data_simulator.py pattern""" - client = from_uri(tiled_uri, api_key=tiled_api_key) - return client.shape[0] if hasattr(client, 'shape') and len(client.shape) > 0 else 0 - - -async def process_images_from_tiled( - socket: zmq.asyncio.Socket, - cycles: int, - frames: int, - pause: float, - tiled_uri: str, - tiled_api_key: str = None -): - """ - Process images from Tiled and send them via ZMQ - """ - try: - # Connect to Tiled server using the exact 
URI format - logger.info(f"Connecting to Tiled server at {tiled_uri}") - client = from_uri(tiled_uri, api_key=tiled_api_key) - - # Get total number of available frames - total_frames = get_num_frames(tiled_uri, tiled_api_key) - logger.info(f"Total frames available in Tiled: {total_frames}") - - for cycle_num in range(cycles): - # Get current time formatted as YYYY-MM-DD HH:MM:SS - current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") - - # Send SAS Start event - start = SASStart( - width=FRAME_WIDTH, - height=FRAME_HEIGHT, - data_type=DATA_TYPE, - tiled_url=tiled_uri, - run_name=f"tiled_run_{cycle_num}", - run_id=str(current_time), - ) - logger.info(f"Sending start event for cycle {cycle_num}") - await socket.send(msgpack.packb(start.model_dump())) - - # Determine number of frames to process in this cycle - frame_count = min(frames, total_frames) - if frame_count == 0: - logger.warning("No frames available in Tiled") - continue - - # Process frames - for frame_num in range(frame_count): - try: - # Get frame data directly from Tiled client - # This is the correct way to read data from a Tiled array client - # The read() method returns a numpy array with the image data - image_array = client[frame_num] - - # Create and send the frame event with the Tiled URI - event = RawFrameEvent( - image=SerializableNumpyArrayModel(array=image_array), - frame_number=frame_num, - tiled_url=f"{tiled_uri}?slice={frame_num}", - ) - logger.info(f"Sending frame {frame_num} for cycle {cycle_num}") - await socket.send(msgpack.packb(event.model_dump())) - - except Exception as e: - logger.error(f"Error processing frame {frame_num}: {e}") - - # Send stop event - stop = SASStop(num_frames=frame_count) - logger.info(f"Sending stop event for cycle {cycle_num}") - await socket.send(msgpack.packb(stop.model_dump())) - - await asyncio.sleep(pause) - logger.info(f"Cycle {cycle_num} complete - sent {frame_count} frames") - - logger.info("All cycles complete") - - except Exception as e: - logger.error(f"Error in processing images: {e}") - - -@app.command() -def main( - cycles: int = 10000, - frames: int = 50, - pause: float = 5, - tiled_uri: str = None, - api_key: str = None -): - """ - Run the image simulator that reads frames from Tiled and publishes them via ZMQ. 
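Both the deleted simulators and the unified one below publish msgpack-packed SASStart, RawFrameEvent, and SASStop payloads on a ZMQ PUB socket. A minimal subscriber sketch; the address is an assumption (whatever `settings.tiled_poller` resolves to in a given deployment), and keying on field names is a simplification of the real schema handling:

```python
import msgpack
import zmq

ctx = zmq.Context()
sub = ctx.socket(zmq.SUB)
sub.connect("tcp://127.0.0.1:5555")       # assumed publisher address
sub.setsockopt_string(zmq.SUBSCRIBE, "")  # subscribe to everything

while True:
    msg = msgpack.unpackb(sub.recv(), raw=False)
    if "run_id" in msg:                   # SASStart-like payload
        print("start:", msg.get("run_name"))
    elif "frame_number" in msg:           # RawFrameEvent-like payload
        print("frame:", msg["frame_number"])
    elif "num_frames" in msg:             # SASStop-like payload
        print("stop after", msg["num_frames"], "frames")
        break
```

Because PUB/SUB silently drops messages sent before a subscriber connects, a consumer like this needs to be running before the start event goes out.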
- - Args: - cycles: Number of cycles to run - frames: Maximum number of frames per cycle - pause: Pause time between cycles in seconds - tiled_uri: URI of the Tiled server (defaults to the predefined DATA_TILED_URI) - api_key: API key for Tiled authentication (defaults to env var DATA_TILED_KEY) - """ - # Use provided values or fall back to defaults - tiled_uri = tiled_uri or DATA_TILED_URI - api_key = api_key or TILED_API_KEY - - async def run(): - context = zmq.asyncio.Context() - socket = context.socket(zmq.PUB) - address = settings.tiled_poller.publish_address - logger.info(f"Binding to ZMQ address: {address}") - socket.bind(address) - await process_images_from_tiled(socket, cycles, frames, pause, tiled_uri, api_key) - return - - asyncio.run(run()) - - -if __name__ == "__main__": - app() \ No newline at end of file diff --git a/src/arroyosas/app/unified_sim_cli.py b/src/arroyosas/app/unified_sim_cli.py new file mode 100644 index 0000000..be6b477 --- /dev/null +++ b/src/arroyosas/app/unified_sim_cli.py @@ -0,0 +1,638 @@ +import asyncio +import glob +import json +import logging +import os +import re +import aiosqlite +from datetime import datetime +from typing import List, Optional, Tuple + +import typer +import zmq +import zmq.asyncio +import msgpack +import numpy as np +from PIL import Image +from tiled.client import from_uri + +from ..config import settings +from ..schemas import ( + RawFrameEvent, + SASStart, + SASStop, + SerializableNumpyArrayModel, +) + +""" +Unified simulator combining: +- db_replay_sim_cli.py +- local_tiled_sim_cli.py +- real_image_sim_cli.py +""" + +# Setup logging +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + +app = typer.Typer() + +# ============================================================================= +# FROM: db_replay_sim_cli.py +# ============================================================================= +# Get configuration from environment variables with defaults +DEFAULT_DB_PATH = os.getenv("DB_PATH", "latent_vectors.db") +TILED_ENV = os.getenv("TILED_ENV", "dev").lower() + +# Define environment-specific URLs +TILED_URLS = { + "dev": { + "url_pattern": "http://tiled-dev.nsls2.bnl.gov/api/v1/array/full/" + }, + "prod": { + "url_pattern": "http://tiled.nsls2.bnl.gov/api/v1/array/full/" + } +} + + +async def get_urls_from_db(db_path, limit=None): + """Get a list of Tiled URLs from the database asynchronously""" + try: + async with aiosqlite.connect(db_path) as conn: + query = "SELECT id, tiled_url FROM vectors ORDER BY id" + if limit: + query += f" LIMIT {limit}" + + async with conn.execute(query) as cursor: + results = await cursor.fetchall() + + if not results: + logger.warning(f"No Tiled URLs found in database {db_path}") + return [] + + logger.info(f"Found {len(results)} Tiled URLs in database") + return results + except Exception as e: + logger.error(f"Error reading from database: {e}") + return [] + + +def transform_url_for_env(tiled_url, env): + """ + Transform a Tiled URL to match the specified environment format. 
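`get_urls_from_db` above only touches two columns of a `vectors` table. A sketch that seeds a disposable test database in that shape (the real table presumably also stores the latent vectors themselves; the URLs here are dummies):

```python
import asyncio
import aiosqlite

async def seed(db_path: str = "latent_vectors.db") -> None:
    async with aiosqlite.connect(db_path) as conn:
        await conn.execute(
            "CREATE TABLE IF NOT EXISTS vectors (id INTEGER PRIMARY KEY, tiled_url TEXT)"
        )
        await conn.executemany(
            "INSERT INTO vectors (tiled_url) VALUES (?)",
            [
                (f"http://tiled-dev.nsls2.bnl.gov/api/v1/array/full/some-uuid/primary/data/img?slice={i}",)
                for i in range(3)
            ],
        )
        await conn.commit()

asyncio.run(seed())
```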
+ + Args: + tiled_url: The original Tiled URL (typically from dev environment) + env: The target environment ('dev' or 'prod') + + Returns: + str: Transformed URL for the target environment + """ + if env not in TILED_URLS: + logger.warning(f"Unknown environment '{env}', falling back to 'dev'") + env = "dev" + + # If we're staying in dev, no transformation needed + if env == "dev" and "tiled-dev.nsls2.bnl.gov" in tiled_url: + return tiled_url + + # Extract slice parameter if present + slice_param = None + if '?' in tiled_url: + slice_param = tiled_url.split('?')[1] + + # Parse the URL to extract UUID and stream path + url_without_query = tiled_url.split('?')[0] # Remove query parameters + + # Extract UUID and stream path + uuid = None + stream_path = None + + if 'array/full/' in url_without_query: + path_after_full = url_without_query.split('array/full/')[1] + parts = path_after_full.split('/') + if len(parts) >= 1: + uuid = parts[0] + if len(parts) > 1: + stream_path = '/'.join(parts[1:]) + + if not uuid or not stream_path: + logger.error(f"Could not parse Tiled URL: {tiled_url}") + return tiled_url # Return original if parsing fails + + # Transform URL based on environment + if env == "prod": + # Get image name from stream_path + parts = stream_path.split('/') + if len(parts) > 0: + image_name = parts[-1] + # Format: http://tiled.nsls2.bnl.gov/api/v1/array/full/smi/raw/{uuid}/primary/data/{image_name}?slice=... + new_url = f"{TILED_URLS[env]['url_pattern']}smi/raw/{uuid}/primary/data/{image_name}" + else: + # Fallback if we can't extract image name + logger.error(f"Could not extract image name from stream path: {stream_path}") + return tiled_url + else: + # Dev URL format: http://tiled-dev.nsls2.bnl.gov/api/v1/array/full/{uuid}/{stream_path}?slice=... + new_url = f"{TILED_URLS[env]['url_pattern']}{uuid}/{stream_path}" + + # Add slice parameter if it exists + if slice_param: + new_url = f"{new_url}?{slice_param}" + + logger.debug(f"Transformed URL: {tiled_url} -> {new_url}") + return new_url + + +def _read_image_from_tiled_url_sync(tiled_url, api_key=None): + """ + Read an image from a Tiled URL. + + Args: + tiled_url: The Tiled URL (already transformed for the appropriate environment) + api_key: API key for Tiled authentication + + Returns: + tuple: (image_data, index) + """ + try: + # Extract index from slice parameter + index = 0 # Default index + if '?' 
in tiled_url and 'slice=' in tiled_url: + slice_param = tiled_url.split('slice=')[1].split('&')[0] + if ':' in slice_param: + parts = slice_param.split(',')[0].split(':') + if parts[0].isdigit(): + index = int(parts[0]) + + # Parse the URL to extract base URL and path + url_without_query = tiled_url.split('?')[0] # Remove query parameters + url_parts = url_without_query.split('/api/v1/') + + if len(url_parts) != 2: + logger.error(f"Invalid Tiled URL format: {tiled_url}") + return None, 0 + + # Change array/full to metadata + base_uri = f"{url_parts[0]}/api/v1/metadata" + + # Extract dataset URI - get everything after "array/full/" + full_path = url_parts[1] + + if 'array/full/' in url_without_query: + # If the URL contains array/full, extract the part after it + path_parts = full_path.split('array/full/') + if len(path_parts) > 1: + dataset_uri = path_parts[1] + else: + dataset_uri = full_path + else: + # If URL doesn't contain array/full, use the whole path + dataset_uri = full_path + + logger.debug(f"Base URI: {base_uri}, Dataset URI: {dataset_uri}, Index: {index}") + + # Connect to the Tiled server + client = from_uri(base_uri, api_key=api_key) + + # Access the dataset + tiled_data = client[dataset_uri] + logger.debug(f"Dataset shape: {tiled_data.shape}, dtype: {tiled_data.dtype}") + + # Retrieve the image at the specified index + image = tiled_data[index] + + return image, index + + except Exception as e: + logger.error(f"Error reading from Tiled URL {tiled_url}: {e}") + return None, 0 + + +async def read_image_from_tiled_url(tiled_url, api_key=None): + """Async wrapper for _read_image_from_tiled_url_sync""" + return await asyncio.to_thread(_read_image_from_tiled_url_sync, tiled_url, api_key) + + +# ============================================================================= +# FROM: local_tiled_sim_cli.py +# ============================================================================= + +# Default settings +URL_FILE = os.getenv("URL_FILE", "./tiled_url.json") + + +def load_url_from_file(file_path: str) -> tuple: + """Load Tiled URL and metadata from a file""" + try: + if not os.path.exists(file_path): + logger.error(f"URL file not found: {file_path}") + return None, None + + with open(file_path, 'r') as f: + data = json.load(f) + + tiled_url = data.get("tiled_url") + metadata = data.get("metadata", {}) + + if not tiled_url: + logger.error(f"No Tiled URL found in {file_path}") + return None, None + + logger.info(f"Loaded Tiled URL from {file_path}: {tiled_url}") + return tiled_url, metadata + + except Exception as e: + logger.error(f"Error loading URL from file: {e}") + return None, None + + +async def fetch_image_from_tiled(client, container_name: str, image_key: str) -> np.ndarray: + """Fetch a single image from Tiled as a numpy array""" + try: + # Get the container + container = client[container_name] + + # Get the image data + array_client = container[image_key] + + # Convert to numpy array + data = array_client.read() + + # Ensure it's a numpy array + if not isinstance(data, np.ndarray): + data = np.array(data, dtype=np.uint32) + + return data + except Exception as e: + logger.error(f"Error fetching image {image_key}: {e}") + # Return a small dummy array in case of error + return np.zeros((10, 10), dtype=np.uint32) + + +async def get_matching_keys(client, container_name: str, pattern: str) -> List[str]: + """Get all keys in a container matching a pattern""" + try: + # Get the container + container = client[container_name] + + # Get all keys + all_keys = list(container.keys()) + 
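The escaping loop just below rebuilds a regex from the stored pattern `{run_id}_[0-9]{4}`. An equivalent, more direct sketch: escape the run id and require exactly four digits, using `fullmatch` so a key like `..._00001` cannot slip through on a prefix match:

```python
import re
from typing import List

def matching_keys(all_keys: List[str], run_id: str) -> List[str]:
    pattern = re.compile(re.escape(run_id) + r"_\d{4}")
    return sorted(k for k in all_keys if pattern.fullmatch(k))

keys = ["run_0000", "run_0001", "run_00012", "run_metadata"]
print(matching_keys(keys, "run"))  # ['run_0000', 'run_0001']
```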
logger.info(f"Found {len(all_keys)} total keys in container") + + # Create regex from pattern - properly escape special characters + regex = pattern + for char in '[]{}.': + regex = regex.replace(char, '\\' + char) + # Replace the digit pattern with actual regex pattern + regex = regex.replace('\\[0-9\\]\\{4\\}', r'\d{4}') + regex_pattern = re.compile(regex) + + logger.info(f"Using regex pattern: {regex}") + + # Filter keys by regex + matching_keys = sorted([k for k in all_keys if regex_pattern.match(k)]) + logger.info(f"Found {len(matching_keys)} matching image keys") + + return matching_keys + except Exception as e: + logger.error(f"Error getting matching keys: {e}") + return [] + + +# ============================================================================= +# FROM: real_image_sim_cli.py +# ============================================================================= + +# Default Tiled configuration - use the exact URI format from data_simulator.py +DATA_TILED_URI = ( + "https://tiled-demo.blueskyproject.io/api/v1/metadata/rsoxs/raw/" + "468810ed-2ff9-4e92-8ca9-dcb376d01a56/primary/data/Small Angle CCD Detector_image" +) + +# Frame configuration +FRAME_WIDTH = 1024 +FRAME_HEIGHT = 1026 +DATA_TYPE = "uint32" + + +def get_num_frames(tiled_uri, tiled_api_key=None): + """Get the number of available frames from Tiled, following data_simulator.py pattern""" + client = from_uri(tiled_uri, api_key=tiled_api_key) + return client.shape[0] if hasattr(client, 'shape') and len(client.shape) > 0 else 0 + + +async def process_images_from_tiled( + socket: zmq.asyncio.Socket, + cycles: int, + frames: int, + pause: float, + tiled_uri: str, + tiled_api_key: str = None +): + """ + Process images from Tiled and send them via ZMQ + """ + try: + # Connect to Tiled server using the exact URI format + logger.info(f"Connecting to Tiled server at {tiled_uri}") + client = from_uri(tiled_uri, api_key=tiled_api_key) + + # Get total number of available frames + total_frames = get_num_frames(tiled_uri, tiled_api_key) + logger.info(f"Total frames available in Tiled: {total_frames}") + + for cycle_num in range(cycles): + # Get current time formatted as YYYY-MM-DD HH:MM:SS + current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + + # Send SAS Start event + start = SASStart( + width=FRAME_WIDTH, + height=FRAME_HEIGHT, + data_type=DATA_TYPE, + tiled_url=tiled_uri, + run_name=f"tiled_run_{cycle_num}", + run_id=str(current_time), + ) + logger.info(f"Sending start event for cycle {cycle_num}") + await socket.send(msgpack.packb(start.model_dump())) + + # Determine number of frames to process in this cycle + frame_count = min(frames, total_frames) + if frame_count == 0: + logger.warning("No frames available in Tiled") + continue + + # Process frames + for frame_num in range(frame_count): + try: + # Get frame data directly from Tiled client + # This is the correct way to read data from a Tiled array client + # The read() method returns a numpy array with the image data + image_array = client[frame_num] + + # Create and send the frame event with the Tiled URI + event = RawFrameEvent( + image=SerializableNumpyArrayModel(array=image_array), + frame_number=frame_num, + tiled_url=f"{tiled_uri}?slice={frame_num}", + ) + logger.info(f"Sending frame {frame_num} for cycle {cycle_num}") + await socket.send(msgpack.packb(event.model_dump())) + + except Exception as e: + logger.error(f"Error processing frame {frame_num}: {e}") + + # Send stop event + stop = SASStop(num_frames=frame_count) + logger.info(f"Sending stop event for 
cycle {cycle_num}") + await socket.send(msgpack.packb(stop.model_dump())) + + await asyncio.sleep(pause) + logger.info(f"Cycle {cycle_num} complete - sent {frame_count} frames") + + logger.info("All cycles complete") + + except Exception as e: + logger.error(f"Error in processing images: {e}") + + +# ============================================================================= +# UNIFIED CLI +# ============================================================================= + +@app.command() +def main( + mode: str = typer.Option("direct", help="Simulator mode: 'direct', 'db_replay', or 'local_tiled'"), + # db_replay parameters + db_path: str = typer.Option(DEFAULT_DB_PATH, help="[db_replay] Path to the SQLite database"), + max_frames: int = typer.Option(10000, help="[db_replay] Maximum number of frames to process"), + env: str = typer.Option(TILED_ENV, help="[db_replay] Tiled environment ('dev' or 'prod')"), + db_replay_api_key: str = typer.Option(None, help="[db_replay] API key for Tiled authentication"), + # local_tiled parameters + url_file: str = typer.Option(URL_FILE, help="[local_tiled] Path to file containing the Tiled URL"), + cycles: int = typer.Option(1, help="[local_tiled/direct] Number of cycles to run"), + pause: float = typer.Option(0.1, help="[local_tiled/direct] Pause time between frames"), + cycle_pause: float = typer.Option(5.0, help="[local_tiled/direct] Pause time between cycles"), + local_tiled_api_key: str = typer.Option(None, help="[local_tiled] API key for Tiled authentication"), + # direct parameters + frames: int = typer.Option(50, help="[direct] Maximum number of frames per cycle"), + tiled_uri: str = typer.Option(DATA_TILED_URI, help="[direct] URI of the Tiled server"), +): + """ + Unified simulator supporting three modes: + - direct: Direct connection to Tiled URI (from real_image_sim_cli.py) + - db_replay: Replay from SQLite database (from db_replay_sim_cli.py) + - local_tiled: Read from JSON file (from local_tiled_sim_cli.py) + """ + logger.info(f"Starting Unified Simulator in '{mode}' mode") + + async def run(): + # Setup ZMQ socket + context = zmq.asyncio.Context() + socket = context.socket(zmq.PUB) + address = settings.tiled_poller.zmq_frame_publisher.address + logger.info(f"Binding to ZMQ address: {address}") + socket.bind(address) + + if mode == "db_replay": + # FROM: db_replay_sim_cli.py main() + # Use db_replay_api_key or fallback to env var + api_key = db_replay_api_key or os.getenv("TILED_LIVE_API_KEY") + + logger.info(f"DB replay mode with:") + logger.info(f"- Database path: {db_path}") + logger.info(f"- Max frames: {max_frames}") + logger.info(f"- Tiled environment: {env}") + logger.info(f"- API key provided: {api_key is not None}") + + if not os.path.exists(db_path): + logger.error(f"Database file not found: {db_path}") + return + + urls = await get_urls_from_db(db_path, limit=max_frames) + if not urls: + logger.error("No URLs found in database, cannot continue") + return + + current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + start = SASStart( + width=1679, + height=1475, + data_type="uint32", + tiled_url=f"{env}://latent_vectors", + run_name=f"{env}_tiled_run", + run_id=str(current_time), + ) + logger.info(f"Sending start event") + await socket.send(msgpack.packb(start.model_dump())) + + for db_id, tiled_url in urls: + try: + logger.info(f"Processing URL from DB record {db_id}: {tiled_url}") + + transformed_url = transform_url_for_env(tiled_url, env) + logger.info(f"Transformed URL: {transformed_url}") + + image_data, index = await 
read_image_from_tiled_url(transformed_url, api_key) + + if image_data is None: + logger.error(f"Failed to read image from {transformed_url}") + continue + + event = RawFrameEvent( + image=SerializableNumpyArrayModel(array=image_data), + frame_number=index, + tiled_url=transformed_url, + ) + logger.info(f"Sending frame {index}") + await socket.send(msgpack.packb(event.model_dump())) + + await asyncio.sleep(0.1) + + except Exception as e: + logger.error(f"Error processing frame from {tiled_url}: {e}") + + stop = SASStop(num_frames=len(urls)) + logger.info(f"Sending stop event") + await socket.send(msgpack.packb(stop.model_dump())) + logger.info(f"Complete - sent {len(urls)} frames") + + elif mode == "local_tiled": + # FROM: local_tiled_sim_cli.py main() + # Use local_tiled_api_key or fallback to env var + api_key = local_tiled_api_key or os.getenv("TILED_LIVE_API_KEY") + + logger.info(f"Local tiled mode with:") + logger.info(f"- URL file: {url_file}") + logger.info(f"- Cycles: {cycles}") + logger.info(f"- API key provided: {api_key is not None}") + + tiled_url, metadata = load_url_from_file(url_file) + if not tiled_url: + return + + url_without_query = tiled_url.split('?')[0] + url_parts = url_without_query.split('/api/v1/') + + if len(url_parts) != 2: + logger.error(f"Invalid Tiled URL format: {tiled_url}") + return + + base_uri = f"{url_parts[0]}/api/v1/metadata" + full_path = url_parts[1] + + if 'array/full/' in url_without_query: + path_parts = full_path.split('array/full/') + if len(path_parts) > 1: + dataset_uri = path_parts[1] + else: + dataset_uri = full_path + else: + dataset_uri = full_path + + logger.info(f"Connecting to Tiled server at {base_uri}") + logger.info(f"Dataset URI: {dataset_uri}") + + client = from_uri(base_uri, api_key=api_key) + + parts = dataset_uri.split('/') + if len(parts) != 2: + logger.error(f"Invalid dataset URI format: {dataset_uri}") + return + + container_name = parts[0] + run_id_base = parts[1] + + if metadata and "image_pattern" in metadata: + pattern = metadata["image_pattern"] + matching_keys = await get_matching_keys(client, container_name, pattern) + + if not matching_keys: + logger.error(f"No matching keys found for pattern {pattern}") + return + + num_images = len(matching_keys) + else: + num_images = metadata.get("num_images", 0) + if num_images <= 0: + logger.error("No image count found in metadata") + return + + matching_keys = [f"{run_id_base}_{i:04d}" for i in range(num_images)] + + logger.info(f"Found {len(matching_keys)} image keys") + + width = metadata.get("width") + height = metadata.get("height") + data_type = metadata.get("data_type") + + if not all([width, height, data_type]): + logger.error("Missing image dimensions in metadata") + return + + for cycle_num in range(cycles): + current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + + start = SASStart( + width=width, + height=height, + data_type=data_type, + tiled_url=tiled_url, + run_name=f"local_tiled_run_{cycle_num}", + run_id=str(current_time), + ) + logger.info(f"Sending start event for cycle {cycle_num}") + await socket.send(msgpack.packb(start.model_dump())) + + for frame_num, image_key in enumerate(matching_keys): + try: + image_array = await fetch_image_from_tiled(client, container_name, image_key) + + frame_url = f"{url_parts[0]}/api/v1/array/full/{container_name}/{image_key}?slice=0:1,0:1679,0:1475" + + event = RawFrameEvent( + image=SerializableNumpyArrayModel(array=image_array), + frame_number=0, + tiled_url=frame_url, + ) + logger.info(f"Sending image {image_key} for 
cycle {cycle_num}") + await socket.send(msgpack.packb(event.model_dump())) + + except Exception as e: + logger.error(f"Error sending image {image_key}: {e}") + import traceback + logger.error(traceback.format_exc()) + continue + + await asyncio.sleep(pause) + + stop = SASStop(num_frames=len(matching_keys)) + logger.info(f"Sending stop event for cycle {cycle_num}") + await socket.send(msgpack.packb(stop.model_dump())) + + if cycle_num < cycles - 1: + logger.info(f"Cycle {cycle_num} complete - pausing for {cycle_pause}s") + await asyncio.sleep(cycle_pause) + + logger.info(f"All {cycles} cycles complete") + + elif mode == "direct": + # FROM: real_image_sim_cli.py main() + logger.info(f"Direct mode with:") + logger.info(f"- Tiled URI: {tiled_uri}") + logger.info(f"- Cycles: {cycles}") + logger.info(f"- Frames: {frames}") + logger.info(f"- API key is not needed for public tiled dataset.") + + await process_images_from_tiled(socket, cycles, frames, pause, tiled_uri, tiled_api_key=None) + + else: + logger.error(f"Unknown mode: {mode}. Use 'direct', 'db_replay', or 'local_tiled'") + + asyncio.run(run()) + + +if __name__ == "__main__": + app() \ No newline at end of file diff --git a/src/arroyosas/app/xps_websocket_simulator.py b/src/arroyosas/app/xps_websocket_simulator.py new file mode 100644 index 0000000..ad1d4b4 --- /dev/null +++ b/src/arroyosas/app/xps_websocket_simulator.py @@ -0,0 +1,426 @@ +""" +XPS WebSocket Simulator for arroyosas package + +Loads XPS data from combined_all_bin_averages.npy and streams via WebSocket +in the format expected by the XPS processing system (tr_ap_xps format). + +Data structure: Dictionary with keys (bin_num, shot_num) -> 2D numpy array +Example: combined_averages[(5, 70)] = heatmap array + +Usage: + python -m arroyosas.app.xps_websocket_simulator --data-file /data/xps_test/combined_all_bin_averages.npy +""" + +import asyncio +import json +import logging +import os +import uuid +from pathlib import Path +from typing import List, Tuple, Dict + +import msgpack +import numpy as np +import typer +import websockets + +# Setup logging +logger = logging.getLogger(__name__) +logging.basicConfig( + level=logging.INFO, + format='%(levelname)s: (%(name)s) %(message)s' +) + +app = typer.Typer() + + +def load_xps_data(data_file: str) -> Tuple[np.ndarray, Dict[int, Tuple[int, int]]]: + """ + Load XPS data from combined_all_bin_averages.npy + + Args: + data_file: Path to combined_all_bin_averages.npy + + Returns: + Tuple of: + - 3D numpy array (total_entries, height, width) + - Index map: {index: (bin_num, shot_num)} + """ + try: + # Load the combined dictionary + combined_averages = np.load(data_file, allow_pickle=True).item() + + if not isinstance(combined_averages, dict): + raise ValueError(f"Expected dict, got {type(combined_averages)}") + + # Get all keys, sorted by bin then shot number + all_keys = sorted(combined_averages.keys()) + total_entries = len(all_keys) + + if total_entries == 0: + raise ValueError("No data found in file") + + # Get shape from first entry + first_key = all_keys[0] + single_image_shape = combined_averages[first_key].shape + + logger.info(f"Data structure:") + logger.info(f" Total entries: {total_entries}") + logger.info(f" Image shape: {single_image_shape}") + + # Create 3D array + all_data = np.zeros((total_entries, *single_image_shape), dtype=np.float32) + + # Create index map + index_map = {} + + # Fill array and map + for i, key in enumerate(all_keys): + all_data[i] = combined_averages[key] + index_map[i] = key # key is (bin_num, shot_num) 
+ + # Log bin distribution + bins = set(bin_num for bin_num, _ in all_keys) + logger.info(f" Number of bins: {len(bins)}") + logger.info(f" Bin numbers: {sorted(bins)}") + + # Count shots per bin + from collections import defaultdict + shots_per_bin = defaultdict(int) + for bin_num, _ in all_keys: + shots_per_bin[bin_num] += 1 + + for bin_num in sorted(shots_per_bin.keys()): + logger.info(f" Bin {bin_num}: {shots_per_bin[bin_num]} shots") + + return all_data, index_map + + except Exception as e: + logger.error(f"Error loading data file: {e}") + raise + + +def prepare_start_message() -> str: + """ + Prepare XPS start message in JSON format. + Matches the tr_ap_xps XPSStart schema. + """ + scan_uuid = str(uuid.uuid4()) + start_message = { + "msg_type": "start", + "scan_name": f"temp name {scan_uuid}", + "F_Trigger": 13, + "F_Un-Trigger": 38, + "F_Dead": 45, + "F_Reset": 46, + "CCD_nx": 1392, + "CCD_ny": 1040, + "Pass Energy": 200, + "Center Energy": 3308, + "Offset Energy": -0.837, + "Lens Mode": "X6-26Mar2022-test", + "Rectangle": { + "Left": 148, + "Top": 385, + "Right": 1279, + "Bottom": 654, + "Rotation": 0 + }, + "data_type": "U8", + "dt": 0.0820741786426572, + "Photon Energy": 3999.99740398402, + "Binding Energy": 90, + "File Ver": "1.0.0" + } + return json.dumps(start_message) + + +def prepare_stop_message() -> str: + """ + Prepare XPS stop message in JSON format. + Matches the tr_ap_xps XPSStop schema. + """ + stop_message = { + "msg_type": "stop", + "Num Frames": 0 + } + return json.dumps(stop_message) + + +def convert_to_uint8(image: np.ndarray) -> bytes: + """ + Convert an image to uint8 with logarithmic scaling. + Matches the tr_ap_xps/websockets.py convert_to_uint8 function. + """ + if image.size == 0: + return b'' + + # Normalize to [0, 1] + image_normalized = (image - image.min()) / (image.max() - image.min() + 1e-10) + + # Apply logarithmic stretch + log_stretched = np.log1p(image_normalized) + + # Normalize again + log_stretched_normalized = (log_stretched - log_stretched.min()) / ( + log_stretched.max() - log_stretched.min() + 1e-10 + ) + + # Convert to uint8 + image_uint8 = (log_stretched_normalized * 255).astype(np.uint8) + return image_uint8.tobytes() + + +def prepare_xps_message( + shot_num: int, + bin_num: int, + shot_mean: np.ndarray +) -> bytes: + """ + Prepare XPS message matching the tr_ap_xps format. 
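`convert_to_uint8` above ships raw bytes with no shape header, so a consumer has to rebuild the array from the width/height fields carried in the same message. A receiver-side sketch; the sizes are hypothetical:

```python
import numpy as np

def decode_frame(payload: bytes, rows: int, cols: int) -> np.ndarray:
    # Rebuild the 2D uint8 image from the raw bytes of the msgpack payload.
    return np.frombuffer(payload, dtype=np.uint8).reshape(rows, cols)

# Self-test with a stand-in for convert_to_uint8 output; the shape is made up.
img = (np.random.rand(270, 1131) * 255).astype(np.uint8)
assert np.array_equal(decode_frame(img.tobytes(), 270, 1131), img)
```

Since the message dict below deliberately swaps `width` and `height` (see its inline comment), a consumer should reshape using the fields exactly as sent rather than inferring them from the array.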
+
+    Args:
+        shot_num: Shot number within current bin
+        bin_num: Current bin number
+        shot_mean: 2D numpy array with heatmap data
+
+    Returns: msgpack binary message
+    """
+    # Ensure 2D shape
+    if shot_mean.ndim != 2:
+        logger.warning(f"Expected 2D array, got shape {shot_mean.shape}")
+        if shot_mean.ndim == 1:
+            shot_mean = shot_mean.reshape(1, -1)
+        else:
+            shot_mean = shot_mean.reshape(shot_mean.shape[0], -1)
+
+    height, width = shot_mean.shape
+
+    # msgpack with image data (tr_ap_xps format)
+    msgpack_data = {
+        "shot_num": shot_num,
+        "shot_mean": convert_to_uint8(shot_mean),
+        "shot_recent": convert_to_uint8(shot_mean),
+        "shot_std": convert_to_uint8(np.zeros_like(shot_mean)),
+        "width": height,  # Note: swapped in real system
+        "height": width,
+        "raw": convert_to_uint8(shot_mean),
+        "vfft": convert_to_uint8(shot_mean),
+        "ifft": convert_to_uint8(shot_mean),
+        "fitted": json.dumps([])
+    }
+
+    return msgpack.packb(msgpack_data)
+
+
+async def send_xps_data(
+    websocket,
+    shot_num: int,
+    bin_num: int,
+    shot_mean: np.ndarray
+):
+    """Send XPS data through WebSocket"""
+    try:
+        msgpack_msg = prepare_xps_message(shot_num, bin_num, shot_mean)
+        msg_size = len(msgpack_msg)
+
+        await websocket.send(msgpack_msg)
+        logger.info(f"✓ Sent bin {bin_num:2d}, shot {shot_num:2d} | shape={shot_mean.shape} | size={msg_size:,} bytes")
+
+    except Exception as e:
+        logger.error(f"✗ Error sending bin {bin_num}, shot {shot_num}: {e}")
+
+
+async def websocket_handler(
+    websocket,
+    all_data: np.ndarray,
+    index_map: Dict[int, Tuple[int, int]],
+    cycles: int,
+    pause: float,
+    bins_to_send: List[int] = None
+):
+    """
+    Handle WebSocket connection and send XPS data.
+
+    Args:
+        websocket: WebSocket connection
+        all_data: 3D array (entries, height, width)
+        index_map: Maps index -> (bin_num, shot_num)
+        cycles: Number of times to repeat the data
+        pause: Pause between shots in seconds
+        bins_to_send: List of bin numbers to send (None = all bins)
+    """
+    client_info = f"{websocket.remote_address[0]}:{websocket.remote_address[1]}"
+    logger.info("=" * 70)
+    logger.info(f"📌 New WebSocket connection from {client_info}")
+    logger.info(f"📁 Path: {websocket.request.path}")
+    logger.info("=" * 70)
+
+    # Check path
+    if websocket.request.path != "/simImages":
+        logger.warning(f"⚠️ Unexpected path: {websocket.request.path}, expected /simImages")
+
+    try:
+        for cycle in range(cycles):
+            logger.info("")
+            logger.info("=" * 70)
+            logger.info(f"🔄 Cycle {cycle + 1}/{cycles}")
+            logger.info("=" * 70)
+
+            # Group indices by bin
+            from collections import defaultdict
+            bin_indices = defaultdict(list)
+            for idx, (bin_num, shot_num) in index_map.items():
+                if bins_to_send is None or bin_num in bins_to_send:
+                    bin_indices[bin_num].append((idx, shot_num))
+
+            # Send data bin by bin - each bin gets its own START/STOP
+            for bin_num in sorted(bin_indices.keys()):
+                logger.info(f"")
+                logger.info(f"📦 Starting bin {bin_num}")
+
+                # Send start message for this bin
+                start_msg = prepare_start_message()
+                # Extract UUID from the start message for logging
+                start_data = json.loads(start_msg)
+                scan_uuid = start_data['scan_name'].replace('temp name ', '').strip()
+                await websocket.send(start_msg)
+                logger.info(f"📤 Sent START message for bin {bin_num} with UUID: {scan_uuid}")
+                await asyncio.sleep(0.1)
+
+                logger.info(f" Total shots in bin: {len(bin_indices[bin_num])}")
+
+                # Sort shots within bin
+                shots = sorted(bin_indices[bin_num], key=lambda x: x[1])
+
+                # Track timing
+                import time
+                bin_start = time.time()
+
+                # Send all shots for this bin
+                for idx, original_shot_num in shots:
+                    shot_num = original_shot_num
+                    shot_mean = all_data[idx]
+
+                    await send_xps_data(websocket, shot_num, bin_num, shot_mean)
+                    await asyncio.sleep(pause)
+
+                # Send stop message for this bin
+                stop_msg = prepare_stop_message()
+                await websocket.send(stop_msg)
+                logger.info(f"📤 Sent STOP message for bin {bin_num}")
+
+                bin_duration = time.time() - bin_start
+                logger.info(f"✅ Completed bin {bin_num} | {len(shots)} shots in {bin_duration:.2f}s")
+
+            if cycle < cycles - 1:
+                logger.info("")
+                logger.info(f"⏸️ Cycle {cycle + 1} complete, pausing before next cycle...")
+                await asyncio.sleep(1.0)
+
+        logger.info("")
+        logger.info("=" * 70)
+        logger.info(f"🎉 All {cycles} cycles complete!")
+        logger.info("=" * 70)
+        await asyncio.sleep(2.0)
+
+    except websockets.exceptions.ConnectionClosed:
+        logger.info(f"🔌 Client {client_info} disconnected")
+    except Exception as e:
+        logger.error(f"❌ Error in handler: {e}")
+        import traceback
+        logger.error(traceback.format_exc())
+
+
+async def run_server(
+    host: str,
+    port: int,
+    data_file: str,
+    cycles: int,
+    pause: float,
+    bins: str = None
+):
+    """Run the WebSocket server."""
+    # Load data
+    logger.info(f"Loading data from {data_file}...")
+    all_data, index_map = load_xps_data(data_file)
+
+    # Parse bins to send
+    bins_to_send = None
+    if bins:
+        try:
+            bins_to_send = [int(b.strip()) for b in bins.split(',')]
+            logger.info(f"Will send only bins: {bins_to_send}")
+        except ValueError:
+            logger.error(f"Invalid bins format: {bins}. Using all bins.")
+
+    logger.info(f"WebSocket server: ws://{host}:{port}/simImages")
+
+    async def handler(websocket):
+        await websocket_handler(
+            websocket,
+            all_data,
+            index_map,
+            cycles,
+            pause,
+            bins_to_send
+        )
+
+    async with websockets.serve(handler, host, port):
+        logger.info("Server running...")
+        await asyncio.Future()
+
+
+@app.command()
+def main(
+    host: str = typer.Option("0.0.0.0", help="Host to bind"),
+    port: int = typer.Option(8001, help="Port to bind"),
+    data_file: str = typer.Option(
+        "/data/xps_test/combined_all_bin_averages.npy",
+        help="Path to combined_all_bin_averages.npy"
+    ),
+    cycles: int = typer.Option(1, help="Number of cycles to repeat all data"),
+    pause: float = typer.Option(0.1, help="Pause between shots (seconds)"),
+    bins: str = typer.Option(None, help="Comma-separated bin numbers to send (e.g. '0,1,5')"),
+):
+    """
+    XPS WebSocket Simulator
+
+    Loads data from combined_all_bin_averages.npy and streams via WebSocket.
+    Compatible with tr_ap_xps message format.
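A minimal client sketch for the stream served above: text frames carry the JSON start/stop messages, binary frames carry the msgpack shot payloads. Host, port, and path match the CLI defaults; everything else is a simplification:

```python
import asyncio
import json

import msgpack
import websockets

async def consume(uri: str = "ws://localhost:8001/simImages") -> None:
    async with websockets.connect(uri) as ws:
        async for frame in ws:
            if isinstance(frame, (bytes, bytearray)):
                shot = msgpack.unpackb(frame, raw=False)   # binary: shot payload
                print("shot", shot["shot_num"], shot["width"], "x", shot["height"])
            else:
                msg = json.loads(frame)                    # text: start/stop JSON
                print(msg["msg_type"], msg.get("scan_name", ""))

asyncio.run(consume())
```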
+ + Data structure: + - Dictionary with keys (bin_num, shot_num) -> 2D numpy array + - shot_num resets for each bin + - Sends bins in order, with all shots for each bin + + Examples: + # Send all bins + python -m arroyosas.app.xps_websocket_simulator --data-file data.npy + + # Send only bins 0, 1, and 5 + python -m arroyosas.app.xps_websocket_simulator --bins "0,1,5" + + # Repeat 10 times with 0.05s pause + python -m arroyosas.app.xps_websocket_simulator --cycles 10 --pause 0.05 + """ + logger.info("=" * 60) + logger.info("XPS WebSocket Simulator (Bin Structure)") + logger.info("=" * 60) + logger.info(f"Host: {host}:{port}") + logger.info(f"Data file: {data_file}") + logger.info(f"Cycles: {cycles}, Pause: {pause}s") + if bins: + logger.info(f"Bins filter: {bins}") + logger.info("=" * 60) + + if not os.path.exists(data_file): + logger.error(f"Data file not found: {data_file}") + logger.info("Please provide the path to combined_all_bin_averages.npy") + return + + asyncio.run(run_server(host, port, data_file, cycles, pause, bins)) + + +if __name__ == "__main__": + app() \ No newline at end of file
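For reference, a sketch of the intended local workflow, driving the two new CLIs with `subprocess`. The paths and the presence of a running local Tiled server are assumptions; the option names follow Typer's default underscore-to-dash conversion:

```python
import subprocess
import sys

# Step 1: push a folder of images into Tiled and record the resulting URL.
subprocess.run(
    [sys.executable, "-m", "arroyosas.app.ingest_local_images",
     "--image-folder", "./images", "--url-file", "./tiled_url.json"],
    check=True,
)

# Step 2: replay the ingested images over ZMQ via the unified simulator.
subprocess.run(
    [sys.executable, "-m", "arroyosas.app.unified_sim_cli",
     "--mode", "local_tiled", "--url-file", "./tiled_url.json", "--cycles", "1"],
    check=True,
)
```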