diff options
Diffstat (limited to 'main.py')
| -rw-r--r-- | main.py | 514 |
1 files changed, 514 insertions, 0 deletions
@@ -0,0 +1,514 @@ +import sys +import numpy as np +import cv2 +import mapbox_earcut as earcut +import pyray as pr +import tifffile +import topology +import time +import tkinter as tk +from tkinter import filedialog + +def point_dist(p1, p2): + return (p1.x - p2.x)**2 + (p1.y - p2.y)**2 + +def load_image_data(img_path): + print(f"Loading image {img_path}...") + if img_path.lower().endswith('.tif') or img_path.lower().endswith('.tiff'): + img_data = tifffile.imread(img_path) + if len(img_data.shape) == 3 and img_data.shape[0] in [1, 3, 4]: + img_data = np.transpose(img_data, (1, 2, 0)) + else: + img_data = cv2.imread(img_path) + if img_data is not None: + img_data = cv2.cvtColor(img_data, cv2.COLOR_BGR2RGB) + return img_data + +def load_segmentation_data(seg_path, height, width): + print(f"Loading segmentation {seg_path}...") + if seg_path.lower().endswith('.npy'): + seg_data = np.load(seg_path, allow_pickle=True) + if seg_data.shape == (): # It's a dict + seg_data = seg_data.item()['masks'] + else: # bin file + seg_data = np.fromfile(seg_path, dtype=np.uint16) + seg_data = seg_data.reshape((height, width)) + return seg_data + +def create_texture_from_numpy(img_data): + if len(img_data.shape) == 2: # grayscale + img_rgb = cv2.cvtColor(img_data, cv2.COLOR_GRAY2RGB) + else: + # TIF might have alpha or weird channels, assume first 3 are RGB for now + img_rgb = img_data[:, :, :3] + if img_rgb.dtype != np.uint8: # normalize to 8-bit if it's 16-bit tiff + img_rgb = cv2.normalize(img_rgb, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U) + + cv2.imwrite("tmp_bg.png", cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)) + return pr.load_texture("tmp_bg.png") + +def open_file_dialog(): + root = tk.Tk() + root.withdraw() + # Attempt to make it topmost for some window managers + root.attributes('-topmost', True) + file_path = filedialog.askopenfilename( + title="Select Image or Segmentation File", + filetypes=[ + ("Image/Seg files", "*.png *.jpg *.jpeg *.tif *.tiff *.npy *.bin"), + ("Image files", "*.png *.jpg *.jpeg *.tif *.tiff"), + ("Segmentation files", "*.npy *.bin"), + ("All files", "*.*") + ] + ) + root.destroy() + return file_path + +def save_file_dialog(): + root = tk.Tk() + root.withdraw() + root.attributes('-topmost', True) + file_path = filedialog.asksaveasfilename( + title="Save Segmentation As", + defaultextension=".npy", + filetypes=[ + ("NPY files", "*.npy"), + ("Binary files", "*.bin"), + ("All files", "*.*") + ] + ) + root.destroy() + return file_path + +def main(): + img_path = sys.argv[1] if len(sys.argv) > 1 else None + seg_path = sys.argv[2] if len(sys.argv) > 2 else None + + img_data = None + vertices = [] + regions = [] + width, height = 800, 600 # Default window size if no image + + if img_path: + img_data = load_image_data(img_path) + if img_data is not None: + height, width = img_data.shape[:2] + if seg_path: + seg_data = load_segmentation_data(seg_path, height, width) + print("Extracting shared boundary vertices...") + vertices, regions = topology.extract_boundaries(seg_data) + else: + print(f"Failed to load image: {img_path}") + img_path = None + + # Raylib Initialization + print("Initialize Window...") + scale_factor = 1.0 + if img_data is not None: + max_dim = 1000 + if max(width, height) > max_dim: + scale_factor = max_dim / max(width, height) + window_w = int(width * scale_factor) + window_h = int(height * scale_factor) + else: + window_w, window_h = 800, 600 + + pr.set_config_flags(pr.FLAG_WINDOW_RESIZABLE) + pr.init_window(window_w, window_h, "Segmentation Editor") + pr.set_target_fps(60) + + bg_texture = None + if img_data is not None: + bg_texture = create_texture_from_numpy(img_data) + + # State + dragging_vertex_idx = -1 + hovered_vertex_idx = -1 + last_click_time = 0.0 + selected_region_idx = -1 + selection_time = 0.0 + + empty_selection_origin = None + empty_selection_time = 0.0 + + # Selection radius (scaled by zoom later implicitly by world coordinates) + PICK_RADIUS = 10.0 + + # Camera for panning/zooming + camera = pr.Camera2D() + camera.target = pr.Vector2(0, 0) + camera.offset = pr.Vector2(0, 0) + camera.rotation = 0.0 + camera.zoom = scale_factor + + while not pr.window_should_close(): + + if pr.window_should_close(): break + + def handle_file_load(path): + nonlocal img_data, height, width, bg_texture, img_path, vertices, regions, seg_path + low_path = path.lower() + if low_path.endswith(('.png', '.jpg', '.jpeg', '.tif', '.tiff')): + # New Image + new_img = load_image_data(path) + if new_img is not None: + img_data = new_img + height, width = img_data.shape[:2] + if bg_texture: pr.unload_texture(bg_texture) + bg_texture = create_texture_from_numpy(img_data) + img_path = path + # Reset + vertices = [] + regions = [] + print(f"Loaded image: {path}") + elif low_path.endswith(('.npy', '.bin')): + if img_data is not None: + seg_data = load_segmentation_data(path, height, width) + print("Extracting shared boundary vertices...") + new_vertices, new_regions = topology.extract_boundaries(seg_data) + vertices, regions = new_vertices, new_regions + seg_path = path + print(f"Loaded segmentation: {path}") + else: + print("Please load an image first!") + + def save_segmentation(out_path): + nonlocal vertices, regions, width, height, seg_path + print(f"Saving modified mask to {out_path}...") + new_mask = topology.reconstruct_mask(vertices, regions, width, height) + + # Read original dict to keep image parity + if out_path.endswith('.npy'): + if seg_path and seg_path.endswith('.npy'): + orig_data = np.load(seg_path, allow_pickle=True) + if orig_data.shape == (): # It's a dict + new_dict = orig_data.item().copy() + new_dict['masks'] = new_mask + np.save(out_path, new_dict) + print(f"Saved merged dict to {out_path}") + return + + np.save(out_path, new_mask) + print(f"Saved array to {out_path}") + elif out_path.endswith('.bin'): + with open(out_path, "wb") as f: + f.write(new_mask.tobytes()) + print(f"Saved binary to {out_path}") + else: + # Default to NPY + np.save(out_path if out_path.endswith('.npy') else out_path + ".npy", new_mask) + print(f"Saved to {out_path}") + + # File Picker Shortcut (Ctrl+O) + if pr.is_key_down(pr.KEY_LEFT_CONTROL) and pr.is_key_pressed(pr.KEY_O): + picked_path = open_file_dialog() + if picked_path: + handle_file_load(picked_path) + + # File Drag and Drop Handling + if pr.is_file_dropped(): + dropped_files = pr.load_dropped_files() + for i in range(dropped_files.count): + # FilePathList.paths is a char**, we need to convert to python string + dropped_path = pr.ffi.string(dropped_files.paths[i]).decode('utf-8') + handle_file_load(dropped_path) + pr.unload_dropped_files(dropped_files) + + mouse_pos = pr.get_mouse_position() + world_mouse_pos = pr.get_screen_to_world_2d(mouse_pos, camera) + + # Find hovered vertex globally + hovered_vertex_idx = -1 + min_dist = float('inf') + # scale picking radius inversely with zoom so the apparent hit circle size remains constant + dynamic_pick_radius_sq = (PICK_RADIUS / camera.zoom)**2 + + # Optimization: in a huge graph, you'd use a generic spatial index here (e.g. quadtree) + # For ~10k vertices array scan is usually fine in Python for 60fps + for i, (vx, vy) in enumerate(vertices): + v_pos = pr.Vector2(vx, vy) + d = point_dist(world_mouse_pos, v_pos) + if d < dynamic_pick_radius_sq and d < min_dist: + min_dist = d + hovered_vertex_idx = i + + # Handle Mouse Input + if pr.is_mouse_button_pressed(pr.MOUSE_BUTTON_LEFT): + if hovered_vertex_idx != -1: + dragging_vertex_idx = hovered_vertex_idx + else: + # Check for double click region selection + current_time = time.time() + if current_time - last_click_time < 0.5: + # Double click detected in empty space, test regions + pt = (world_mouse_pos.x, world_mouse_pos.y) + clicked_region = -1 + for r_idx, region in enumerate(regions): + poly_pts = [vertices[i] for i in region['vertex_indices']] + if len(poly_pts) >= 3: + # pointPolygonTest needs float32 numpy array + poly_arr = np.array(poly_pts, dtype=np.float32) + dist = cv2.pointPolygonTest(poly_arr, pt, False) + if dist >= 0: + clicked_region = r_idx + break + if clicked_region != -1: + selected_region_idx = clicked_region + selection_time = current_time + empty_selection_origin = None + print(f"Region {selected_region_idx} selected for deletion") + else: + selected_region_idx = -1 + empty_selection_origin = pr.Vector2(world_mouse_pos.x, world_mouse_pos.y) + empty_selection_time = current_time + print(f"Empty space selected for creation at {empty_selection_origin.x}, {empty_selection_origin.y}") + last_click_time = current_time + + elif pr.is_mouse_button_released(pr.MOUSE_BUTTON_LEFT): + dragging_vertex_idx = -1 + + # Handle Dragging + if dragging_vertex_idx != -1: + vertices[dragging_vertex_idx][0] = world_mouse_pos.x + vertices[dragging_vertex_idx][1] = world_mouse_pos.y + + # Camera Panning (Right Click) + if pr.is_mouse_button_down(pr.MOUSE_BUTTON_RIGHT): + delta = pr.get_mouse_delta() + delta.x = delta.x * -1.0 / camera.zoom + delta.y = delta.y * -1.0 / camera.zoom + camera.target = pr.vector2_add(camera.target, delta) + + # Camera Panning (Arrow Keys) + pan_speed = 10.0 / camera.zoom + if pr.is_key_down(pr.KEY_RIGHT): + camera.target.x += pan_speed + if pr.is_key_down(pr.KEY_LEFT): + camera.target.x -= pan_speed + if pr.is_key_down(pr.KEY_DOWN): + camera.target.y += pan_speed + if pr.is_key_down(pr.KEY_UP): + camera.target.y -= pan_speed + + # Camera Zooming (Scroll) + wheel = pr.get_mouse_wheel_move() + if wheel != 0: + mouse_world_pos = pr.get_screen_to_world_2d(pr.get_mouse_position(), camera) + camera.offset = pr.get_mouse_position() + camera.target = mouse_world_pos + camera.zoom += wheel * 0.1 + if camera.zoom < 0.1: camera.zoom = 0.1 + + # Deletion logic + if pr.is_key_pressed(pr.KEY_D): + if selected_region_idx != -1 and time.time() - selection_time < 5.0: + print(f"Deleting region {selected_region_idx}") + regions.pop(selected_region_idx) + selected_region_idx = -1 + + # Garbage Collect unreferenced vertices to clean up visual clutter + used_indices = set() + for r in regions: + used_indices.update(r['vertex_indices']) + + # We must rebuild the vertices array to exclude orphans and remap region indices + new_vertices = [] + index_map = {} # old_idx -> new_idx + + for old_idx, v in enumerate(vertices): + if old_idx in used_indices: + new_idx = len(new_vertices) + new_vertices.append(v) + index_map[old_idx] = new_idx + + # Reassign vertices mapping + vertices = new_vertices + # Remap region references + for r in regions: + r['vertex_indices'] = [index_map[idx] for idx in r['vertex_indices']] + + print(f"Garbage collection removed {len(index_map) - len(new_vertices)} orphaned vertices.") + + # Creation Logic + if pr.is_key_pressed(pr.KEY_N): + current_time = time.time() + if empty_selection_origin is not None and current_time - empty_selection_time < 5.0: + print("Creating new region...") + # Create 50x50 square centered at empty_selection_origin + ox, oy = empty_selection_origin.x, empty_selection_origin.y + half_size = 25.0 + + new_pts = [ + [ox - half_size, oy - half_size], + [ox + half_size, oy - half_size], + [ox + half_size, oy + half_size], + [ox - half_size, oy + half_size] + ] + + new_indices = [] + for pt in new_pts: + new_indices.append(len(vertices)) + vertices.append(pt) + + # Find new unique ID + existing_ids = [r['original_id'] for r in regions] + new_uid = max(existing_ids) + 1 if existing_ids else 1 + + color = pr.Color(np.random.randint(50, 255), np.random.randint(50, 255), np.random.randint(50, 255), 255) + + regions.append({ + 'original_id': new_uid, + 'vertex_indices': new_indices, + 'color': color + }) + + empty_selection_origin = None # consume the selection + print(f"Created region {new_uid}") + elif selected_region_idx != -1 and current_time - selection_time < 5.0: + print(f"Adding vertex to region {selected_region_idx}...") + + region = regions[selected_region_idx] + indices = region['vertex_indices'] + + # Find the closest edge to insert the new vertex + min_dist = float('inf') + best_insert_idx = -1 + best_insert_pt = None + + m_pt = np.array([world_mouse_pos.x, world_mouse_pos.y]) + + for i in range(len(indices)): + idx1 = indices[i] + idx2 = indices[(i+1) % len(indices)] + + v1 = np.array(vertices[idx1]) + v2 = np.array(vertices[idx2]) + + # Compute distance from point to line segment + l2 = np.sum((v1 - v2)**2) + if l2 == 0.0: + dist = np.linalg.norm(m_pt - v1) + proj_pt = v1 + else: + t = max(0, min(1, np.dot(m_pt - v1, v2 - v1) / l2)) + proj_pt = v1 + t * (v2 - v1) + dist = np.linalg.norm(m_pt - proj_pt) + + if dist < min_dist: + min_dist = dist + best_insert_idx = (i + 1) % len(indices) + best_insert_pt = proj_pt.tolist() + + if best_insert_idx != -1 and best_insert_pt is not None: + # Insert the new vertex into the global array + new_v_idx = len(vertices) + vertices.append(best_insert_pt) + + # Insert the reference into the region's topological loop + # We insert at `best_insert_idx` to split the edge + if best_insert_idx == 0: + # if it's the wrap-around edge, append to end + indices.append(new_v_idx) + else: + indices.insert(best_insert_idx, new_v_idx) + + print(f"Added vertex to region {selected_region_idx} at {best_insert_pt}") + + # Saving function + if pr.is_key_pressed(pr.KEY_S): + if pr.is_key_down(pr.KEY_LEFT_CONTROL): + out_path = save_file_dialog() + if out_path: + save_segmentation(out_path) + else: + # Default S saves to original seg path if possible, else tmp + out = "tmp_modified_seg.npy" + if seg_path: + # Actually, user said keep original 's' for tmp save + pass + save_segmentation(out) + + pr.begin_drawing() + pr.clear_background(pr.RAYWHITE) + + if bg_texture: + pr.begin_mode_2d(camera) + # Draw Background + pr.draw_texture(bg_texture, 0, 0, pr.WHITE) + else: + pr.draw_text("No Image Loaded", pr.get_screen_width()//2 - 100, pr.get_screen_height()//2 - 20, 20, pr.GRAY) + pr.draw_text("Drag and drop an image file here", pr.get_screen_width()//2 - 150, pr.get_screen_height()//2 + 10, 15, pr.LIGHTGRAY) + pr.begin_mode_2d(camera) + + # Draw Region Boundaries + current_time_render = time.time() + for idx, region in enumerate(regions): + indices = region['vertex_indices'] + if len(indices) < 2: + continue + + color = region['color'] + + # Draw filled context if selected for deletion + if idx == selected_region_idx and (current_time_render - selection_time) < 5.0: + poly_pts = [vertices[i] for i in indices] + poly_arr = np.array(poly_pts, dtype=np.float32) + try: + triangles = earcut.triangulate_float32(poly_arr, np.array([len(poly_pts)], dtype=np.uint32)) + fill_color = pr.Color(color.r, color.g, color.b, 100) + for i in range(0, len(triangles), 3): + p1 = pr.Vector2(poly_pts[triangles[i]][0], poly_pts[triangles[i]][1]) + p2 = pr.Vector2(poly_pts[triangles[i+1]][0], poly_pts[triangles[i+1]][1]) + p3 = pr.Vector2(poly_pts[triangles[i+2]][0], poly_pts[triangles[i+2]][1]) + # Earcut often generates clockwise, draw backwards + pr.draw_triangle(p1, p3, p2, fill_color) + except Exception: + pass + + # draw line strip manually + for i in range(len(indices)): + idx1 = indices[i] + idx2 = indices[(i+1) % len(indices)] # wrap around + v1 = vertices[idx1] + v2 = vertices[idx2] + p1 = pr.Vector2(v1[0], v1[1]) + p2 = pr.Vector2(v2[0], v2[1]) + + # Make lines thicker depending on zoom to be visible + line_thick = max(1.0, 2.0 / camera.zoom) + pr.draw_line_ex(p1, p2, line_thick, color) + + # Draw Vertices + # Only draw vertices if zoomed in enough, to prevent clutter on full view + if camera.zoom > 0.5: + vert_radius = max(2.0, 3.0 / camera.zoom) + for i, (vx, vy) in enumerate(vertices): + color = pr.RED if (i == hovered_vertex_idx or i == dragging_vertex_idx) else pr.BLUE + pr.draw_circle_v(pr.Vector2(vx, vy), vert_radius, color) + + # Draw Creation Crosshair + if empty_selection_origin is not None and (time.time() - empty_selection_time) < 5.0: + ch_size = 10.0 / camera.zoom + ch_thick = max(1.0, 2.0 / camera.zoom) + p_center = empty_selection_origin + pr.draw_line_ex(pr.Vector2(p_center.x - ch_size, p_center.y), pr.Vector2(p_center.x + ch_size, p_center.y), ch_thick, pr.RED) + pr.draw_line_ex(pr.Vector2(p_center.x, p_center.y - ch_size), pr.Vector2(p_center.x, p_center.y + ch_size), ch_thick, pr.RED) + + pr.end_mode_2d() + + # UI Overlay + pr.draw_text("Segmentation Point Editor", 10, 10, 20, pr.BLACK) + pr.draw_text("Left Click + Drag point: Move boundary", 10, 40, 10, pr.DARKGRAY) + pr.draw_text("Double Left Click: Select mask/empty space", 10, 55, 10, pr.DARKGRAY) + pr.draw_text("'D' Key: Delete selected mask | 'N' Key: Create mask", 10, 70, 10, pr.DARKGRAY) + pr.draw_text("'Ctrl+O': Open | 'S': Tmp Save | 'Ctrl+S': Save As", 10, 85, 10, pr.DARKGRAY) + pr.draw_text("Right Click / Arrows: Pan camera | Mouse Wheel: Zoom", 10, 100, 10, pr.DARKGRAY) + + pr.end_drawing() + + if bg_texture: + pr.unload_texture(bg_texture) + pr.close_window() + +if __name__ == "__main__": + main() |
