import os

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image, ImageEnhance, ImageFilter

# Prefer a CJK-capable font so the Chinese subplot titles render, falling back
# to common Latin fonts when SimHei is not installed.
plt.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans', 'Arial']
# Draw negative signs with an ASCII hyphen instead of the Unicode minus glyph
# (presumably because the CJK font above may lack that glyph).
plt.rcParams['axes.unicode_minus'] = False


class DataAugmentor:
    """Collect augmented variants of one image for side-by-side display.

    Every augmentation method derives a new image from the untouched original,
    stores it in ``self.results`` under a display name, and returns ``self``
    so that calls can be chained. Calling a method twice with the same
    ``key_name`` (or the same defaults) overwrites the earlier entry.
    """

    def __init__(self, image_path):
        """Load the image as RGB and register it as the first result.

        Args:
            image_path: Path of the image file to augment.
        """
        # ``Image.open`` is lazy and keeps the file handle open. Converting
        # inside the ``with`` block produces an independent in-memory image
        # and releases the handle immediately.
        with Image.open(image_path) as source:
            self.original = source.convert('RGB')
        self.results = {'原图': self.original}
        print(f"图片尺寸: {self.original.size}")
        print(f"图片模式: {self.original.mode}")

    def horizontal_flip(self, key_name="水平翻转"):
        """Mirror the original left-to-right and store it under ``key_name``."""
        self.results[key_name] = self.original.transpose(Image.FLIP_LEFT_RIGHT)
        return self

    def vertical_flip(self, key_name="垂直翻转"):
        """Mirror the original top-to-bottom and store it under ``key_name``."""
        self.results[key_name] = self.original.transpose(Image.FLIP_TOP_BOTTOM)
        return self

    def rotate(self, angle=30, key_name=None):
        """Rotate the original about its centre.

        Args:
            angle: Rotation angle in degrees, counter-clockwise.
            key_name: Result name; defaults to a label containing ``angle``.
        """
        if key_name is None:
            key_name = f"旋转{angle}°"
        # expand=False keeps the canvas size; uncovered corners become black.
        self.results[key_name] = self.original.rotate(
            angle, expand=False, fillcolor='black'
        )
        return self

    def random_crop(self, crop_ratio=0.8, key_name=None):
        """Crop a random window and scale it back to the original size.

        Args:
            crop_ratio: Side-length fraction of the window, in ``(0, 1]``.
            key_name: Result name; defaults to a label with the percentage.

        Raises:
            ValueError: If ``crop_ratio`` is outside ``(0, 1]``.
        """
        # Reject ratios that would give an empty window or one larger than
        # the image, instead of failing later inside numpy or Pillow.
        if not 0 < crop_ratio <= 1:
            raise ValueError(
                f"crop_ratio must be in the interval (0, 1], got {crop_ratio!r}"
            )
        if key_name is None:
            key_name = f"随机裁剪({int(crop_ratio*100)}%)"

        w, h = self.original.size
        # Keep at least one pixel per side so tiny ratios stay valid.
        crop_w = max(1, int(w * crop_ratio))
        crop_h = max(1, int(h * crop_ratio))

        # Pick the top-left corner so the window lies fully inside the image
        # (the upper bound of ``randint`` is exclusive, hence the ``+ 1``).
        left = np.random.randint(0, w - crop_w + 1)
        top = np.random.randint(0, h - crop_h + 1)
        cropped = self.original.crop((left, top, left + crop_w, top + crop_h))

        # Scale back so every variant has the same dimensions.
        self.results[key_name] = cropped.resize((w, h), Image.LANCZOS)
        return self

    def adjust_brightness(self, factor=1.5, key_name=None):
        """Change the brightness of the original.

        Args:
            factor: ``> 1`` brightens, ``< 1`` darkens.
            key_name: Result name; defaults to a label containing ``factor``.
        """
        if key_name is None:
            key_name = f"亮度×{factor}"
        enhancer = ImageEnhance.Brightness(self.original)
        self.results[key_name] = enhancer.enhance(factor)
        return self

    def adjust_contrast(self, factor=1.5, key_name=None):
        """Change the contrast of the original.

        Args:
            factor: ``> 1`` raises contrast, ``< 1`` lowers it.
            key_name: Result name; defaults to a label containing ``factor``.
        """
        if key_name is None:
            key_name = f"对比度×{factor}"
        enhancer = ImageEnhance.Contrast(self.original)
        self.results[key_name] = enhancer.enhance(factor)
        return self

    def adjust_saturation(self, factor=1.8, key_name=None):
        """Change the colour saturation of the original.

        Args:
            factor: ``> 1`` is more vivid, ``< 1`` is duller, ``0`` yields a
                greyscale-looking image.
            key_name: Result name; defaults to a label containing ``factor``.
        """
        if key_name is None:
            key_name = f"饱和度×{factor}"
        enhancer = ImageEnhance.Color(self.original)
        self.results[key_name] = enhancer.enhance(factor)
        return self

    def add_gaussian_noise(self, mean=0, std=25, key_name=None):
        """Add per-pixel, per-channel Gaussian noise to the original.

        Args:
            mean: Mean of the noise distribution.
            std: Standard deviation of the noise; larger means stronger noise.
            key_name: Result name; defaults to a label containing ``std``.
        """
        if key_name is None:
            key_name = f"高斯噪声(σ={std})"
        # Work in floating point so the sum can leave [0, 255] before clipping.
        img_array = np.array(self.original, dtype=np.float32)
        noise = np.random.normal(mean, std, img_array.shape)
        noisy = np.clip(img_array + noise, 0, 255).astype(np.uint8)
        self.results[key_name] = Image.fromarray(noisy)
        return self

    def gaussian_blur(self, radius=3, key_name=None):
        """Blur the original with a Gaussian kernel (e.g. to mimic defocus).

        Args:
            radius: Blur radius; larger means blurrier.
            key_name: Result name; defaults to a label containing ``radius``.
        """
        if key_name is None:
            key_name = f"高斯模糊(r={radius})"
        self.results[key_name] = self.original.filter(
            ImageFilter.GaussianBlur(radius=radius)
        )
        return self

    def color_jitter(self, brightness=1.3, contrast=1.3, saturation=1.5,
                     key_name="颜色抖动"):
        """Apply brightness, contrast and saturation changes in sequence.

        Intended to mimic a different lighting environment.

        Args:
            brightness: Brightness factor, applied first.
            contrast: Contrast factor, applied second.
            saturation: Saturation factor, applied last.
            key_name: Result name.
        """
        # Each ``enhance`` call returns a new image, so the original is never
        # modified and no defensive copy is needed.
        img = ImageEnhance.Brightness(self.original).enhance(brightness)
        img = ImageEnhance.Contrast(img).enhance(contrast)
        img = ImageEnhance.Color(img).enhance(saturation)
        self.results[key_name] = img
        return self

    def translate(self, dx=30, dy=20, key_name=None):
        """Shift the original by ``(dx, dy)`` pixels, filling gaps with black.

        Args:
            dx: Horizontal displacement in pixels (positive moves right).
            dy: Vertical displacement in pixels (positive moves down).
            key_name: Result name; defaults to a label with both offsets.
        """
        if key_name is None:
            key_name = f"平移({dx},{dy})"
        # Pillow's affine coefficients map each OUTPUT pixel to its INPUT
        # position, so moving the content by (+dx, +dy) needs the inverse
        # offsets (-dx, -dy).
        self.results[key_name] = self.original.transform(
            self.original.size,
            Image.AFFINE,
            (1, 0, -dx, 0, 1, -dy),
            fillcolor='black',
        )
        return self

    def show_all(self, cols=4, figsize=(16, 12), save_path=None):
        """Draw every stored result in a grid and optionally save the figure.

        Args:
            cols: Number of images per row; must be at least 1.
            figsize: Figure size passed to matplotlib.
            save_path: If truthy, the figure is also written to this path.

        Returns:
            The matplotlib figure that was drawn.

        Raises:
            ValueError: If ``cols`` is smaller than 1.
        """
        if cols < 1:
            raise ValueError(f"cols must be at least 1, got {cols!r}")

        n = len(self.results)
        rows = max(1, (n + cols - 1) // cols)  # ceiling division
        # squeeze=False always yields a 2-D array of axes, so flattening is
        # uniform even for one row, one column, or a single image.
        fig, grid = plt.subplots(rows, cols, figsize=figsize, squeeze=False)
        axes = grid.ravel()

        for ax, (name, img) in zip(axes, self.results.items()):
            ax.imshow(img)
            ax.set_title(name, fontsize=14, fontweight='bold')
            ax.axis('off')

        # Hide the unused cells of the last row.
        for ax in axes[n:]:
            ax.axis('off')

        fig.suptitle('数据增强效果总览', fontsize=20, fontweight='bold', y=0.98)
        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=150, bbox_inches='tight',
                        facecolor='white')
            print(f"图片已保存至: {save_path}")

        plt.show()
        return fig


def demo_with_custom_image(image_path="demo.jpg"):
    """Run the full augmentation demo on the given image.

    If ``image_path`` does not exist, a picture of coloured geometric shapes
    is generated at that path and used instead.

    Args:
        image_path: Path of the image to augment.

    Returns:
        The ``DataAugmentor`` holding the original and all generated variants.
    """
    # Fall back to a generated picture when the requested file is missing.
    if not os.path.exists(image_path):
        print(f"未找到 '{image_path}',正在生成演示图片...")
        _create_demo_image(image_path)

    print("=" * 50)
    print("数据增强演示开始")
    print("=" * 50)

    augmentor = DataAugmentor(image_path)

    # Geometric transforms
    augmentor.horizontal_flip()
    augmentor.vertical_flip()
    augmentor.rotate(angle=25)
    augmentor.rotate(angle=-15)
    augmentor.random_crop(0.75)            # random window, 75% per side
    augmentor.translate(40, -20)

    # Colour transforms
    augmentor.adjust_brightness(0.5)       # darker
    augmentor.adjust_brightness(1.8)       # brighter
    augmentor.adjust_contrast(0.4)         # low contrast
    augmentor.adjust_contrast(2.0)         # high contrast
    augmentor.adjust_saturation(0.3)       # low saturation
    augmentor.adjust_saturation(2.5)       # high saturation
    augmentor.color_jitter(1.4, 1.3, 1.6)  # combined colour change

    # Noise and blur
    augmentor.add_gaussian_noise(std=20)   # mild noise
    augmentor.add_gaussian_noise(std=45)   # strong noise
    augmentor.gaussian_blur(radius=2)      # mild blur
    augmentor.gaussian_blur(radius=5)      # strong blur

    print(f"\n✅ 共生成 {len(augmentor.results)} 种变体")
    print("正在绘制可视化结果...\n")

    augmentor.show_all(cols=4, figsize=(18, 12),
                       save_path="data_augmentation_demo.png")

    print("\n统计信息:")
    print(" - 原始图片: 1张")
    print(f" - 增强变体: {len(augmentor.results) - 1}张")
    print(f" - 总数据量提升: {len(augmentor.results)}倍")

    return augmentor


def _create_demo_image(save_path):
    """Draw a 400x400 picture of coloured shapes and save it to ``save_path``.

    Args:
        save_path: Destination file path for the generated image.
    """
    from PIL import ImageDraw

    img = Image.new('RGB', (400, 400), color='white')
    draw = ImageDraw.Draw(img)

    # Large coloured disc
    draw.ellipse([80, 80, 320, 320], fill='coral', outline='darkred', width=5)
    # Rectangle drawn over the disc
    draw.rectangle([120, 120, 280, 200], fill='skyblue', outline='navy',
                   width=3)

    # Row of five small dots, 50 px apart, one colour each
    dot_colors = ['gold', 'limegreen', 'purple', 'orange', 'pink']
    y = 300
    for i, color in enumerate(dot_colors):
        x = 100 + i * 50
        draw.ellipse([x - 15, y - 15, x + 15, y + 15], fill=color)

    # Caption
    draw.text((120, 350), "Data Augmentation", fill='darkblue')

    img.save(save_path)
    print(f"演示图片已保存为: {save_path}")


if __name__ == "__main__":
    # Run the demo; replace 'demo.jpg' with the path of your own image,
    # e.g. demo_with_custom_image("my_cat.jpg").
    result = demo_with_custom_image("demo.jpg")