说明
官网地址:https://github.com/visinf/INSID3
在计算机视觉领域,图像分割一直是一项极具挑战性的任务。传统方法往往需要大量标注数据、长时间模型训练,以及针对不同物体类别的反复微调。但今天,这一切将被彻底改变——INSID3 横空出世,它以一种全新的 In-Context(上下文学习) 范式,让你仅需提供一张参考图像和其掩码,就能在任意目标图像中精准分割出同款物体,完全无需任何训练!
更令人兴奋的是,通过 ONNX Runtime 的加持,我们可以将 INSID3 高效地部署到生产环境中,实现跨平台、跨语言的高性能推理。本文将从零开始,带你了解 INSID3 的原理、优势,并手把手教你如何使用 Python + ONNX Runtime 快速部署 INSID3,让你轻松拥有一次标注、处处分割的“超能力”。
INSID3(In-context Segmentation with a Non‑trainable DINOv3 encoder)是一种无需训练的分割模型,它的核心思想非常简单:
利用一个冻结的大型视觉编码器(DINOv3)提取图像特征,然后通过巧妙的特征匹配和聚类机制,将参考物体的语义“传递”到目标图像上。
换句话说,你只需要在参考图上勾画或选择一个物体,INSID3 就能自动在其他图像中找到同样的物体,并生成精确的分割掩码。整个过程没有训练、没有梯度下降、不依赖特定类别,真正做到了开箱即用、万物皆可分割。
效果
模型信息
Model Properties
-------------------------
Inputs
-------------------------
name: imgs
tensor: Float[-1, 3, 1024, 1024]
-------------------------
Outputs
-------------------------
name: f_norm
tensor: Float[-1, 1024, 64, 64]
name: f_debias
tensor: Float[-1, 1024, 64, 64]
-------------------------
代码
class INSID3App:
    """Tkinter front end for interactive INSID3 in-context segmentation.

    The window shows three canvases: the reference image (on which the user
    clicks polygon vertices), the target image, and the segmentation result.
    The polygon is rasterised into a ``MODEL_SIZE`` x ``MODEL_SIZE`` mask that
    is handed, together with the reference image, to ``OnnxINSID3``.

    Args:
        root: The ``tk.Tk`` (or ``Toplevel``) window that hosts the widgets.
        onnx_path: Path of the ONNX encoder passed to ``OnnxINSID3``.
        device: Value forwarded as ``OnnxINSID3(..., device=...)``. Defaults
            to ``"cuda"``, which is what the application always used before
            this parameter existed.
    """

    # Side length of the square mask produced by gen_mask().
    # NOTE(review): presumably matches the 1024x1024 encoder input — confirm
    # against OnnxINSID3's preprocessing.
    MODEL_SIZE = 1024

    # RGBA colour painted over masked pixels (semi-transparent red).
    OVERLAY_RGBA = (255, 0, 0, 100)

    def __init__(self, root, onnx_path="insid3_encoder_nobatch.onnx", device="cuda"):
        self.root = root
        self.root.title("INSID3 交互式分割")
        self.onnx_path = onnx_path
        self.device = device
        self.ref_orig = None      # reference image as loaded (RGB)
        self.tgt_orig = None      # target image as loaded (RGB)
        self.ref_disp = None      # reference image prepared for the canvas
        self.tgt_disp = None      # target image prepared for the canvas
        self.mask_np = None       # polygon mask as a uint8 array
        self.mask_pil = None      # the same mask as a PIL image
        self.result = None        # last segmentation result (PIL image)
        self.display_size = 500   # canvas side length in pixels
        self.points = []          # polygon vertices in canvas coordinates
        self.cached_model = None  # OnnxINSID3 holding encoded reference features
        self._build_ui()

    # ------------------------------------------------------------------ UI --

    def _build_ui(self):
        """Create the toolbar, the three canvases, the progress bar and the status bar."""
        toolbar = tk.Frame(self.root)
        toolbar.pack(side=tk.TOP, fill=tk.X, padx=5, pady=5)
        buttons = [
            ("加载参考图", self.load_ref),
            ("加载目标图", self.load_tgt),
            ("清除多边形", self.clear_poly),
            ("生成掩码", self.gen_mask),
            ("缓存参考", self.cache_ref),
            ("推理", self.run_inference),
            ("保存结果", self.save_result),
        ]
        for text, cmd in buttons:
            tk.Button(toolbar, text=text, command=cmd).pack(side=tk.LEFT, padx=2)
        self.info_var = tk.StringVar(value="选点: 0")
        tk.Label(toolbar, textvariable=self.info_var).pack(side=tk.LEFT, padx=10)

        cf = tk.Frame(self.root)
        cf.pack(fill=tk.BOTH, expand=True)

        self.cv_ref = tk.Canvas(cf, width=self.display_size, height=self.display_size,
                                bg='gray', cursor="cross")
        self.cv_ref.grid(row=0, column=0, padx=5)
        tk.Label(cf, text="参考图 (左键加点 / 右键删点)").grid(row=1, column=0)

        self.cv_tgt = tk.Canvas(cf, width=self.display_size, height=self.display_size,
                                bg='gray')
        self.cv_tgt.grid(row=0, column=1, padx=5)
        tk.Label(cf, text="目标图").grid(row=1, column=1)

        self.cv_res = tk.Canvas(cf, width=self.display_size, height=self.display_size,
                                bg='gray')
        self.cv_res.grid(row=0, column=2, padx=5)
        tk.Label(cf, text="分割结果").grid(row=1, column=2)

        # Left click adds a polygon vertex, right click removes the last one.
        self.cv_ref.bind("<Button-1>", self.add_point)
        self.cv_ref.bind("<Button-3>", self.rm_last)

        self.prog = ttk.Progressbar(self.root, mode='indeterminate')
        self.prog.pack(fill=tk.X)
        self.stvar = tk.StringVar(value="就绪")
        tk.Label(self.root, textvariable=self.stvar, bd=1, relief=tk.SUNKEN,
                 anchor=tk.W).pack(fill=tk.X)

    # ------------------------------------------------------------- helpers --

    @staticmethod
    def _open_rgb(path):
        """Open *path* as an RGB image; show an error dialog and return None on failure."""
        try:
            with Image.open(path) as img:
                return img.convert("RGB")
        except OSError as exc:
            messagebox.showerror("加载失败", str(exc))
            return None

    def _reset_reference_state(self):
        """Drop everything derived from the previous reference image.

        Both mask representations are cleared together; clearing only
        ``mask_np`` would let a later cache/inference pair the new reference
        image with the previous image's ``mask_pil``.
        """
        self.points.clear()
        self.mask_np = None
        self.mask_pil = None
        self.cached_model = None

    @staticmethod
    def _mask_to_rgba(mask, color=(255, 0, 0, 100)):
        """Return an ``(H, W, 4)`` uint8 array with *color* where *mask* is true.

        Pixels where *mask* is false are fully transparent ``(0, 0, 0, 0)``.
        """
        hits = np.asarray(mask, dtype=bool)
        rgba = np.zeros(hits.shape + (4,), dtype=np.uint8)
        rgba[hits] = color
        return rgba

    def _to_model_points(self, points):
        """Scale canvas coordinates to ``MODEL_SIZE`` mask coordinates (truncating)."""
        size = self.MODEL_SIZE
        disp = self.display_size
        # Multiply before dividing so exact integer results are not lost to
        # floating-point rounding of the scale factor.
        return [(int(x * size / disp), int(y * size / disp)) for x, y in points]

    def _overlay(self, base, mask_1024):
        """Composite a translucent red layer over *base* where *mask_1024* exceeds 128."""
        side = (self.display_size, self.display_size)
        small = Image.fromarray(mask_1024).resize(side, Image.NEAREST)
        layer = Image.fromarray(self._mask_to_rgba(np.asarray(small) > 128, self.OVERLAY_RGBA))
        return Image.alpha_composite(base, layer)

    def _redraw(self):
        """Repaint the reference canvas: image, mask overlay, vertices and edges."""
        if self.ref_disp is None:
            return
        base = self.ref_disp.copy().convert("RGBA")
        if self.mask_np is not None and self.mask_np.any():
            base = self._overlay(base, self.mask_np)
        draw = ImageDraw.Draw(base)
        radius = 4
        for x, y in self.points:
            draw.ellipse((x - radius, y - radius, x + radius, y + radius), fill='green')
        if len(self.points) > 1:
            draw.line(self.points, fill='yellow', width=2)
        # Keep a reference on self, otherwise Tk loses the image when it is
        # garbage collected.
        self.ref_tk = ImageTk.PhotoImage(base)
        self.cv_ref.delete("all")
        self.cv_ref.create_image(0, 0, anchor=tk.NW, image=self.ref_tk)
        self.info_var.set(f"选点:{len(self.points)}")

    # ------------------------------------------------------------ commands --

    def load_ref(self):
        """Ask for a reference image and reset all state tied to the old one."""
        path = filedialog.askopenfilename(filetypes=[("Image", "*.jpg *.jpeg *.png *.bmp")])
        if not path:
            return
        img = self._open_rgb(path)
        if img is None:
            return
        self.ref_orig = img
        self.ref_disp, _ = letterbox_image(self.ref_orig, self.display_size)
        self._reset_reference_state()
        self._redraw()
        self.stvar.set(f"参考: {path}")

    def load_tgt(self):
        """Ask for a target image and show it on the target canvas."""
        path = filedialog.askopenfilename(filetypes=[("Image", "*.jpg *.jpeg *.png *.bmp")])
        if not path:
            return
        img = self._open_rgb(path)
        if img is None:
            return
        self.tgt_orig = img
        self.tgt_disp, _ = letterbox_image(self.tgt_orig, self.display_size)
        self.tgt_tk = ImageTk.PhotoImage(self.tgt_disp)
        self.cv_tgt.delete("all")
        self.cv_tgt.create_image(0, 0, anchor=tk.NW, image=self.tgt_tk)
        self.stvar.set(f"目标: {path}")

    def add_point(self, event):
        """Append the clicked canvas position as a polygon vertex."""
        self.points.append((event.x, event.y))
        self._redraw()

    def rm_last(self, event):
        """Remove the most recently added vertex, if there is one."""
        if self.points:
            self.points.pop()
            self._redraw()

    def clear_poly(self):
        """Remove every polygon vertex (an already generated mask is kept)."""
        self.points.clear()
        self._redraw()

    def gen_mask(self):
        """Rasterise the polygon into the reference mask.

        Requires at least three vertices. Any cached reference encoding is
        discarded, because it was computed with the previous mask.
        """
        if len(self.points) < 3:
            messagebox.showwarning("警告", "至少需要3个顶点")
            return
        # NOTE(review): this is a plain display->mask stretch. If
        # letterbox_image pads the displayed reference, clicked coordinates
        # include that padding — confirm the mapping matches what OnnxINSID3
        # expects.
        pts = self._to_model_points(self.points)
        mask = np.zeros((self.MODEL_SIZE, self.MODEL_SIZE), dtype=np.uint8)
        cv2.fillPoly(mask, [np.array(pts, dtype=np.int32)], 255)
        self.mask_np = mask
        # A 2-D uint8 array is read as mode "L" without an explicit mode.
        self.mask_pil = Image.fromarray(mask)
        self.cached_model = None
        self._redraw()
        self.stvar.set("掩码已生成")

    def cache_ref(self):
        """Encode the reference image and mask once so later inferences can reuse them."""
        if self.ref_orig is None or self.mask_pil is None:
            messagebox.showwarning("缺少数据", "请先加载参考图并生成掩码")
            return
        try:
            model = OnnxINSID3(self.onnx_path, device=self.device)
            model.encode_reference(self.ref_orig, self.mask_pil)
        except Exception as e:
            messagebox.showerror("缓存失败", str(e))
            return
        # Publish the model only after encoding succeeded.
        self.cached_model = model
        self.stvar.set("参考特征已缓存,可连续推理")

    def run_inference(self):
        """Validate inputs and start segmentation on a background thread."""
        if self.ref_orig is None or self.mask_pil is None or self.tgt_orig is None:
            messagebox.showwarning("缺少数据", "请加载参考图、掩码和目标图")
            return
        self.cv_res.delete("all")
        self.prog.start()
        self.stvar.set("推理中...")
        threading.Thread(target=self._inf, daemon=True).start()

    def _inf(self):
        """Worker-thread body: segment the target and schedule the UI updates.

        Widget updates are scheduled with ``root.after`` rather than performed
        directly from this thread.
        """
        start = time.time()
        try:
            cached = self.cached_model
            if cached is not None and cached._cached_ref_norm is not None:
                result_pil = cached.segment_with_cache(self.tgt_orig)
            else:
                model = OnnxINSID3(self.onnx_path, device=self.device)
                model.set_reference(self.ref_orig, self.mask_pil)
                model.set_target(self.tgt_orig)
                result_pil = model.segment()
            self.result = result_pil
            elapsed = time.time() - start
            self.root.after(0, self._show)
            self.root.after(0, lambda e=elapsed: self.stvar.set(f"完成,耗时 {e:.2f}s"))
        except Exception as e:
            # Bind the exception as a default argument: the name is cleared
            # when the except block ends, before the callback runs.
            self.root.after(0, lambda e=e: messagebox.showerror("推理错误", str(e)))
        finally:
            self.root.after(0, self.prog.stop)

    def _show(self):
        """Draw the target image with the segmentation result overlaid in red."""
        if self.result is None:
            return
        side = (self.display_size, self.display_size)
        base = self.tgt_orig.resize(side).convert("RGBA")
        # Reduce to a single channel so thresholding yields a 2-D mask
        # whatever mode the model wrapper returned.
        res_disp = self.result.convert("L").resize(side)
        layer = Image.fromarray(self._mask_to_rgba(np.asarray(res_disp) > 128, self.OVERLAY_RGBA))
        base = Image.alpha_composite(base, layer)
        self.res_tk = ImageTk.PhotoImage(base)
        self.cv_res.delete("all")
        self.cv_res.create_image(0, 0, anchor=tk.NW, image=self.res_tk)

    def save_result(self):
        """Save the last segmentation result to a PNG chosen by the user."""
        if self.result is None:
            messagebox.showwarning("无结果", "请先推理")
            return
        path = filedialog.asksaveasfilename(defaultextension=".png",
                                            filetypes=[("PNG", "*.png")])
        if path:
            self.result.save(path)
            messagebox.showinfo("保存成功", f"已保存至 {path}")


if __name__ == "__main__":
    root = tk.Tk()
    # The non-simplified export "insid3_encoder_nobatch.onnx" (the class
    # default) can be passed here instead.
    app = INSID3App(root, onnx_path="insid3_encoder_nobatch_simplified.onnx")
    root.mainloop()