跳转到主要内容
TjMakeBot 博客tjmakebot.com

YOLO数据集制作完整指南:从零到模型训练

TjMakeBot 团队技术教程12 分钟
技术教程实用方法
YOLO数据集制作完整指南:从零到模型训练

🎯 引言:YOLO 为什么这么火?

"我想用YOLO做一个目标检测项目,但不知道从哪里开始..."

这是很多AI开发者的真实困惑。YOLO(You Only Look Once)是目标检测领域广泛使用的算法之一,从YOLOv1到最新的YOLOv10,YOLO系列算法在速度和准确率之间取得了较好的平衡。

YOLO的应用场景

  • 🚗 自动驾驶:实时检测车辆、行人、交通标志
  • 🏭 工业质检:快速检测产品缺陷
  • 🏥 医疗影像:辅助医生识别病灶
  • 🛒 零售分析:商品识别和库存管理
  • 🔒 安防监控:实时监控和异常检测

YOLO的优势

  • 速度快:可以实时处理视频流
  • 准确率高:在速度和准确率之间取得平衡
  • 易用性好:有完善的工具和文档
  • 社区活跃:有大量教程和案例

但很多开发者在使用YOLO时遇到的第一道坎就是:如何制作高质量的YOLO数据集?

今天,我们将从零开始,手把手教你制作一个完整的YOLO数据集,直到模型训练成功。无论你是初学者还是有经验的开发者,都能从这篇文章中找到实用的方法和技巧。

📚 什么是 YOLO 数据集?

YOLO 数据格式

YOLO 使用一种简洁的文本格式来存储标注信息:

文件结构

dataset/
├── images/
│   ├── train/
│   │   ├── image001.jpg
│   │   ├── image002.jpg
│   │   └── ...
│   └── val/
│       ├── image101.jpg
│       └── ...
└── labels/
    ├── train/
    │   ├── image001.txt
    │   ├── image002.txt
    │   └── ...
    └── val/
        ├── image101.txt
        └── ...

标注文件格式image001.txt):

class_id center_x center_y width height
0 0.5 0.5 0.3 0.4
1 0.2 0.3 0.1 0.2

格式说明

  • class_id:类别 ID(从 0 开始)
  • center_x, center_y:边界框中心点的归一化坐标(0-1)
  • width, height:边界框的归一化宽度和高度(0-1)

关键点:YOLO 使用归一化坐标,所有坐标值都在 0-1 之间。

YOLO 版本差异

不同版本的 YOLO 对数据集格式的要求略有不同:

版本 格式要求 特殊说明
YOLOv5 标准格式 支持自定义类别数
YOLOv8 标准格式 推荐使用 Ultralytics 格式
YOLOv9 标准格式 兼容 YOLOv5 格式
YOLOv10 标准格式 最新版本,性能最优

好消息:所有 YOLO 版本都使用相同的数据格式,你的数据集可以通用!

🛠️ 步骤 1:数据收集与准备

1.1 确定数据集需求

在开始之前,明确你的需求是成功的第一步。一个清晰的需求规划可以帮你节省大量时间和成本。

需求分析清单

1. 目标类别定义

明确检测目标

  • 列出所有需要检测的对象类别
  • 定义每个类别的边界(什么算,什么不算)
  • 考虑类别的层次结构(如:车辆 → 汽车、卡车、公交车)

真实案例

某交通监控项目最初只定义了"车辆"一个类别,训练后发现模型无法区分汽车和卡车。后来细分为"car"、"truck"、"bus"、"motorcycle"四个类别,模型准确率提升了15%。

类别数量建议

  • 简单项目:1-5个类别(适合初学者)
  • 中等项目:5-20个类别(常见应用)
  • 复杂项目:20+个类别(需要更多数据和标注时间)

2. 数据规模规划

数据量估算

项目类型 每类最少图片 推荐图片数 总图片数(5类)
快速原型 100-200张 500张 2,500张
生产应用 1,000张 3,000张 15,000张
高精度应用 5,000张 10,000张 50,000张

数据量影响因素

  • 类别数量:类别越多,需要的数据越多
  • 场景复杂度:复杂场景需要更多数据
  • 精度要求:高精度要求需要更多高质量数据
  • 类别平衡:确保各类别数据量相对平衡(比例不超过10:1)

真实案例

某工业质检项目需要检测10种缺陷类型。正常产品有10,000张,但缺陷样本只有500张。通过主动收集缺陷样本和使用数据增强,最终每类缺陷样本达到2,000张,模型准确率从75%提升到92%。

3. 场景多样性规划

场景覆盖维度

时间维度

  • 白天、夜晚、黄昏、黎明
  • 不同季节(春夏秋冬)
  • 不同时间段(早中晚)

天气维度

  • 晴天、雨天、雪天、雾天
  • 不同光照条件(强光、阴影、逆光)

环境维度

  • 室内、室外
  • 城市、乡村、高速公路
  • 不同背景复杂度

目标状态维度

  • 静止、运动
  • 完整、部分遮挡
  • 不同角度(正面、侧面、背面)

场景多样性检查清单

  • ✅ 至少覆盖3-5种主要场景
  • ✅ 包含边界案例(极端情况)
  • ✅ 避免场景过于单一(容易过拟合)
  • ✅ 确保训练集和测试集场景分布一致

4. 图片质量要求

分辨率要求

应用场景 最低分辨率 推荐分辨率 说明
小目标检测 1280×1280 1920×1920+ 需要更高分辨率识别小目标
标准检测 640×640 1280×1280 YOLO默认输入尺寸
快速检测 416×416 640×640 速度优先,精度可接受

图片质量检查

  • ✅ 清晰度:目标对象清晰可见,无模糊
  • ✅ 对比度:目标与背景对比明显
  • ✅ 色彩:色彩真实,无严重失真
  • ✅ 曝光:曝光正常,不过曝或欠曝
  • ✅ 格式:统一格式(JPG或PNG),避免格式混乱

5. 预算和时间规划

时间估算(以5类、每类1000张为例):

阶段 时间估算 说明
数据收集 1-2周 根据数据来源不同
数据标注 2-4周 使用AI辅助可缩短到1周
质量检查 3-5天 多轮检查
格式转换 1天 自动化处理
总计 4-7周 使用AI辅助可缩短到2-3周

成本估算(以5类、每类1000张为例):

方案 标注成本 工具成本 总成本
纯人工标注 $8,000-12,000 $0 $8,000-12,000
AI辅助标注 $1,600-2,400 $0(免费工具) $1,600-2,400
节省 80% - 80%

需求文档模板

# YOLO数据集需求文档

## 项目信息
- 项目名称:[项目名称]
- 应用场景:[应用场景描述]
- 目标精度:[目标mAP值]

## 类别定义
1. [类别1]:[详细定义]
2. [类别2]:[详细定义]
...

## 数据规模
- 类别数量:[N]个
- 每类图片数:[M]张
- 总图片数:[N×M]张

## 场景要求
- 时间:[白天/夜晚/全天]
- 天气:[晴天/雨天/全天候]
- 环境:[室内/室外/混合]

## 质量要求
- 分辨率:[最低分辨率]
- 标注精度:[IoU要求]
- 类别准确率:[准确率要求]

## 时间计划
- 开始时间:[日期]
- 完成时间:[日期]
- 里程碑:[关键节点]

## 预算
- 标注成本:[预算]
- 工具成本:[预算]
- 总预算:[总预算]

1.2 收集图片数据:数据来源全攻略

数据来源1:公开数据集(适合快速开始)

公开数据集是快速开始项目的首选,特别适合学习和原型开发。

主流公开数据集对比

数据集 类别数 图片数 标注数 特点 适用场景
COCO 80 330K 2.5M 质量高,标注精确 通用目标检测
Open Images 600 9M 36M 类别多,数据量大 大规模训练
ImageNet 1000 14M - 分类数据集 预训练模型
Pascal VOC 20 11K 27K 经典数据集 学习研究
Cityscapes 30 25K - 城市街景 自动驾驶

COCO数据集详细说明

下载方法

# 方法1:官方下载
# 访问 https://cocodataset.org/#download
# 下载 train2017.zip, val2017.zip, annotations_trainval2017.zip

# 方法2:使用API下载
from pycocotools.coco import COCO
import requests

# 下载图片和标注

类别列表(部分):

  • 人物:person
  • 车辆:car, truck, bus, motorcycle, bicycle
  • 动物:cat, dog, horse, cow, elephant
  • 家具:chair, couch, bed, table
  • 电子设备:laptop, mouse, keyboard, cell phone

转换为YOLO格式

使用Python脚本转换

from pycocotools.coco import COCO
import json
import os
from PIL import Image

def coco_to_yolo(coco_annotation_file, output_dir):
    """
    将COCO格式转换为YOLO格式
    """
    coco = COCO(coco_annotation_file)
    
    # 创建输出目录
    os.makedirs(f'{output_dir}/images', exist_ok=True)
    os.makedirs(f'{output_dir}/labels', exist_ok=True)
    
    # 获取所有图片ID
    img_ids = coco.getImgIds()
    
    for img_id in img_ids:
        # 获取图片信息
        img_info = coco.loadImgs(img_id)[0]
        img_width = img_info['width']
        img_height = img_info['height']
        
        # 获取该图片的所有标注
        ann_ids = coco.getAnnIds(imgIds=img_id)
        anns = coco.loadAnns(ann_ids)
        
        # 创建YOLO格式的标注文件
        label_file = f"{output_dir}/labels/{img_info['file_name'].replace('.jpg', '.txt')}"
        with open(label_file, 'w') as f:
            for ann in anns:
                # 获取类别ID(YOLO从0开始)
                class_id = ann['category_id'] - 1  # COCO从1开始
                
                # 获取边界框(COCO格式:x, y, width, height)
                bbox = ann['bbox']
                x, y, w, h = bbox
                
                # 转换为YOLO格式(归一化中心点坐标和宽高)
                center_x = (x + w / 2) / img_width
                center_y = (y + h / 2) / img_height
                norm_w = w / img_width
                norm_h = h / img_height
                
                # 写入文件
                f.write(f"{class_id} {center_x} {center_y} {norm_w} {norm_h}\n")
        
        # 复制图片
        # ... (复制图片到images目录)

# 使用
coco_to_yolo('annotations/instances_train2017.json', 'yolo_dataset')

优势

  • ✅ 数据量大,质量高
  • ✅ 标注精确,经过专业审核
  • ✅ 免费使用,无版权问题
  • ✅ 社区支持,有大量教程
  • ✅ 适合快速开始和原型开发

劣势

  • ⚠️ 可能不符合你的具体应用场景
  • ⚠️ 类别可能不够细分
  • ⚠️ 场景可能不够多样化
  • ⚠️ 需要筛选和转换格式

使用建议

  • 适合快速验证想法
  • 适合作为预训练数据
  • 适合学习YOLO使用
  • 不适合生产环境(除非完全匹配你的场景)

数据来源2:自己拍摄(推荐,适合特定场景)

自己拍摄是最可靠的数据来源,可以完全控制数据质量和场景覆盖。

拍摄计划制定

1. 场景覆盖计划

时间覆盖

  • 白天:上午(8-12点)、下午(12-18点)
  • 夜晚:傍晚(18-20点)、深夜(20-24点)
  • 特殊时间:黄昏、黎明、正午强光

拍摄建议

  • 每个时间段至少拍摄100-200张
  • 确保不同时间段的场景多样性
  • 记录拍摄时间和光照条件

天气覆盖

  • 晴天:正常光照,清晰可见
  • 雨天:湿滑路面,反光效果
  • 阴天:柔和光照,无强烈阴影
  • 雾天:能见度低,目标模糊

拍摄建议

  • 每种天气至少拍摄200-300张
  • 注意天气对目标外观的影响
  • 考虑极端天气情况

角度覆盖

  • 正面:0度,目标完整可见
  • 侧面:45度、90度,部分遮挡
  • 俯视:从上往下,适合监控场景
  • 仰视:从下往上,适合特殊视角

距离覆盖

  • 近景:目标占图片50%+,细节清晰
  • 中景:目标占图片20-50%,常见场景
  • 远景:目标占图片5-20%,小目标检测

2. 目标多样性规划

大小多样性

  • 大目标:占图片30-80%,容易检测
  • 中目标:占图片10-30%,标准检测
  • 小目标:占图片1-10%,需要高分辨率

状态多样性

  • 静止:目标静止,清晰可见
  • 运动:目标运动,可能有模糊
  • 部分遮挡:被其他对象遮挡20-50%
  • 严重遮挡:被遮挡50%+(可选,用于鲁棒性训练)

光照多样性

  • 明亮:充足光照,对比明显
  • 阴影:部分在阴影中,对比降低
  • 逆光:目标背光,轮廓清晰但细节模糊
  • 强光:过曝,细节丢失

3. 设备选择与设置

手机拍摄(推荐初学者):

优势

  • ✅ 方便携带,随时拍摄
  • ✅ 自动对焦,操作简单
  • ✅ 现代手机画质足够(1200万像素+)
  • ✅ 成本低,无需额外设备

设置建议

  • 分辨率:设置为最高分辨率(通常4K或更高)
  • 格式:使用JPG格式,平衡质量和文件大小
  • 对焦:确保目标清晰对焦
  • 稳定:使用三脚架或稳定器,避免抖动

相机拍摄(推荐专业项目):

优势

  • ✅ 画质更高,细节更丰富
  • ✅ 可控参数多(ISO、光圈、快门)
  • ✅ 适合专业项目

设置建议

  • ISO:尽量低(100-400),减少噪点
  • 光圈:f/5.6-f/8,平衡景深和画质
  • 快门:1/250s+,避免运动模糊
  • 白平衡:根据场景调整,保持色彩准确

无人机拍摄(适合大场景):

优势

  • ✅ 俯视角度,适合监控场景
  • ✅ 覆盖大范围,效率高
  • ✅ 视角独特,增加数据多样性

注意事项

  • 遵守飞行法规
  • 注意天气条件(风、雨)
  • 确保电池充足

4. 拍摄工作流程

准备阶段(1-2天):

  1. 制定拍摄计划

    • 列出所有需要覆盖的场景
    • 规划拍摄路线和时间
    • 准备设备(相机、存储卡、电池)
  2. 设备检查

    • 检查相机/手机电量
    • 检查存储空间(建议至少100GB)
    • 检查镜头清洁度

拍摄阶段(根据项目规模):

  1. 按计划拍摄

    • 严格按照场景覆盖计划
    • 每个场景至少拍摄50-100张
    • 记录拍摄信息(时间、地点、场景)
  2. 实时检查

    • 定期检查照片质量
    • 删除模糊、失焦的照片
    • 确保目标清晰可见
  3. 数据备份

    • 每天拍摄后立即备份
    • 使用多个存储设备
    • 避免数据丢失

整理阶段(拍摄完成后):

  1. 照片筛选

    • 删除模糊、失焦的照片
    • 删除重复的照片
    • 保留高质量照片
  2. 照片命名

    • 使用有意义的命名规则
    • 例如:scene_time_weather_001.jpg
    • 便于后续管理和标注
  3. 数据统计

    • 统计各类场景的照片数量
    • 检查场景覆盖是否完整
    • 补充缺失的场景

真实案例

案例1:自动驾驶道路场景

某自动驾驶公司需要收集道路场景数据。团队制定了详细的拍摄计划:

  • 时间:白天、夜晚、黄昏各拍摄1个月
  • 天气:晴天、雨天、阴天各拍摄2周
  • 地点:5个不同城市,覆盖城市道路、高速公路、乡村道路
  • 设备:8个车载摄像头,同时拍摄
  • 结果:3个月收集了50,000张高质量图片,覆盖了各种场景

案例2:工业质检产品拍摄

某工厂需要检测产品缺陷。团队使用工业相机:

  • 固定拍摄位置,确保一致性
  • 使用标准光源,减少光照变化
  • 每个产品拍摄多角度(正面、侧面、顶部)
  • 结果:1个月收集了20,000张产品图片,缺陷样本5,000张

拍摄检查清单

设备准备

  • 相机/手机电量充足
  • 存储空间足够(建议100GB+)
  • 镜头清洁,无污渍
  • 备用电池和存储卡

拍摄质量

  • 目标清晰,无模糊
  • 对焦准确,无失焦
  • 曝光正常,不过曝或欠曝
  • 构图合理,目标完整

场景覆盖

  • 时间覆盖完整(白天/夜晚)
  • 天气覆盖完整(晴天/雨天)
  • 角度覆盖完整(正面/侧面)
  • 距离覆盖完整(近景/远景)

数据管理

  • 照片命名规范
  • 数据及时备份
  • 拍摄信息记录完整

数据来源3:视频提取(高效方式)

优势

  • 从视频中提取帧,效率高
  • 可以覆盖连续动作
  • 场景自然

使用TjMakeBot提取

  1. 上传视频文件
  2. 设置提取帧率(如1fps)
  3. 自动提取关键帧
  4. 直接标注提取的帧

技巧

  • 选择关键帧:避免重复帧
  • 设置合适帧率:1-5fps通常足够
  • 处理多个视频:覆盖不同场景

数据来源4:其他来源(需谨慎)

注意事项

  • 遵守数据使用许可协议
  • 尊重知识产权和版权
  • 获得必要的授权或许可
  • 不要使用受版权保护的内容

数据要求检查清单

清晰度

  • 图片清晰,目标对象可见
  • 避免模糊、失焦的图片
  • 分辨率至少640×640

目标大小

  • 目标对象大小适中(建议占图片5%-50%)
  • 避免目标太小(< 1%)或太大(> 80%)
  • 小目标需要更高分辨率

场景多样性

  • 覆盖不同场景
  • 避免过拟合
  • 包含边界案例

目标完整性

  • 标注对象完整
  • 避免严重遮挡(> 50%遮挡)
  • 部分遮挡(< 50%)可以标注

1.3 数据预处理

数据预处理是确保数据质量的关键步骤,直接影响模型训练效果。

预处理流程

步骤1:数据清洗

删除低质量图片

检查项

  • 模糊图片:目标不清晰,无法识别
  • 失焦图片:焦点不在目标上
  • 过曝/欠曝:曝光严重异常
  • 重复图片:完全相同或高度相似
  • 无关图片:不包含目标对象

自动化清洗脚本

import cv2
import numpy as np
import os
from PIL import Image
import imagehash

def calculate_blur_score(image_path):
    """计算图片模糊度"""
    img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    laplacian_var = cv2.Laplacian(img, cv2.CV_64F).var()
    return laplacian_var

def find_duplicates(image_dir, threshold=5):
    """查找重复图片"""
    image_hashes = {}
    duplicates = []
    
    for filename in os.listdir(image_dir):
        if filename.endswith(('.jpg', '.png')):
            filepath = os.path.join(image_dir, filename)
            img_hash = imagehash.average_hash(Image.open(filepath))
            
            # 检查是否有相似的图片
            for existing_file, existing_hash in image_hashes.items():
                if img_hash - existing_hash < threshold:
                    duplicates.append((existing_file, filename))
                    break
            
            image_hashes[filename] = img_hash
    
    return duplicates

def clean_dataset(image_dir, blur_threshold=100):
    """清洗数据集"""
    cleaned_dir = os.path.join(image_dir, 'cleaned')
    os.makedirs(cleaned_dir, exist_ok=True)
    
    removed_count = 0
    
    for filename in os.listdir(image_dir):
        if filename.endswith(('.jpg', '.png')):
            filepath = os.path.join(image_dir, filename)
            
            # 检查模糊度
            blur_score = calculate_blur_score(filepath)
            if blur_score < blur_threshold:
                print(f"删除模糊图片: {filename} (模糊度: {blur_score:.2f})")
                removed_count += 1
                continue
            
            # 复制到清洗后的目录
            import shutil
            shutil.copy(filepath, os.path.join(cleaned_dir, filename))
    
    print(f"清洗完成,删除了 {removed_count} 张低质量图片")

# 使用
clean_dataset('./raw_images')

手动检查

  • 快速浏览所有图片
  • 标记明显有问题的图片
  • 批量删除

步骤2:统一格式

格式选择

格式 优势 劣势 推荐场景
JPG 文件小,加载快 有损压缩 大多数场景(推荐)
PNG 无损压缩,质量高 文件大 需要高质量的场景

转换脚本

from PIL import Image
import os

def convert_format(input_dir, output_dir, target_format='JPG', quality=95):
    """统一图片格式"""
    os.makedirs(output_dir, exist_ok=True)
    
    for filename in os.listdir(input_dir):
        if filename.endswith(('.jpg', '.png', '.bmp', '.tiff')):
            input_path = os.path.join(input_dir, filename)
            output_filename = os.path.splitext(filename)[0] + f'.{target_format.lower()}'
            output_path = os.path.join(output_dir, output_filename)
            
            # 打开并转换
            img = Image.open(input_path)
            
            # 转换为RGB(如果是RGBA)
            if img.mode == 'RGBA':
                rgb_img = Image.new('RGB', img.size, (255, 255, 255))
                rgb_img.paste(img, mask=img.split()[3])
                img = rgb_img
            
            # 保存
            if target_format == 'JPG':
                img.save(output_path, 'JPEG', quality=quality)
            else:
                img.save(output_path, target_format)
            
            print(f"转换: {filename} -> {output_filename}")

# 使用
convert_format('./raw_images', './formatted_images', 'JPG', quality=95)

步骤3:统一尺寸

尺寸选择原则

YOLO输入尺寸

  • 640×640:标准尺寸,平衡速度和精度(推荐)
  • 416×416:快速检测,适合实时应用
  • 1280×1280:高精度检测,适合小目标

调整方法

方法1:保持宽高比缩放(推荐)

from PIL import Image

def resize_with_aspect_ratio(image_path, target_size=640, padding_color=(114, 114, 114)):
    """
    保持宽高比缩放,不足部分用灰色填充
    """
    img = Image.open(image_path)
    original_width, original_height = img.size
    
    # 计算缩放比例
    scale = min(target_size / original_width, target_size / original_height)
    new_width = int(original_width * scale)
    new_height = int(original_height * scale)
    
    # 缩放图片
    img_resized = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
    
    # 创建目标尺寸的画布
    img_padded = Image.new('RGB', (target_size, target_size), padding_color)
    
    # 计算居中位置
    x_offset = (target_size - new_width) // 2
    y_offset = (target_size - new_height) // 2
    
    # 粘贴缩放后的图片
    img_padded.paste(img_resized, (x_offset, y_offset))
    
    return img_padded

# 批量处理
def batch_resize(input_dir, output_dir, target_size=640):
    """批量调整尺寸"""
    os.makedirs(output_dir, exist_ok=True)
    
    for filename in os.listdir(input_dir):
        if filename.endswith(('.jpg', '.png')):
            input_path = os.path.join(input_dir, filename)
            output_path = os.path.join(output_dir, filename)
            
            img_resized = resize_with_aspect_ratio(input_path, target_size)
            img_resized.save(output_path)
            print(f"调整尺寸: {filename}")

# 使用
batch_resize('./formatted_images', './resized_images', target_size=640)

方法2:直接拉伸(不推荐)

  • 会改变目标形状
  • 可能导致模型学习到错误的特征
  • 仅在目标形状不重要时使用

步骤4:数据增强(可选)

何时使用数据增强

  • ✅ 数据量不足时
  • ✅ 需要提高模型泛化能力时
  • ✅ 类别不平衡时

常用增强方法

1. 几何变换

  • 旋转:±15度,模拟不同角度
  • 翻转:水平翻转、垂直翻转
  • 缩放:0.8-1.2倍,模拟不同距离
  • 平移:±10%,模拟位置变化

2. 颜色变换

  • 亮度调整:±20%,模拟不同光照
  • 对比度调整:±20%,增强/减弱对比
  • 饱和度调整:±30%,模拟不同环境
  • 色调调整:±10度,模拟不同光源

3. 噪声添加

  • 高斯噪声:模拟传感器噪声
  • 椒盐噪声:模拟传输错误

增强脚本

from PIL import Image, ImageEnhance
import random
import os

def augment_image(image_path, output_dir, num_augmentations=3):
    """对单张图片进行增强"""
    img = Image.open(image_path)
    base_name = os.path.splitext(os.path.basename(image_path))[0]
    
    for i in range(num_augmentations):
        # 随机旋转
        angle = random.uniform(-15, 15)
        img_rotated = img.rotate(angle, expand=False)
        
        # 随机翻转
        if random.random() > 0.5:
            img_rotated = img_rotated.transpose(Image.FLIP_LEFT_RIGHT)
        
        # 随机亮度调整
        enhancer = ImageEnhance.Brightness(img_rotated)
        img_rotated = enhancer.enhance(random.uniform(0.8, 1.2))
        
        # 随机对比度调整
        enhancer = ImageEnhance.Contrast(img_rotated)
        img_rotated = enhancer.enhance(random.uniform(0.8, 1.2))
        
        # 保存
        output_path = os.path.join(output_dir, f"{base_name}_aug_{i}.jpg")
        img_rotated.save(output_path)
        print(f"增强: {base_name}_aug_{i}.jpg")

def batch_augment(input_dir, output_dir, num_augmentations=3):
    """批量增强"""
    os.makedirs(output_dir, exist_ok=True)
    
    for filename in os.listdir(input_dir):
        if filename.endswith(('.jpg', '.png')):
            input_path = os.path.join(input_dir, filename)
            augment_image(input_path, output_dir, num_augmentations)

# 使用
batch_augment('./resized_images', './augmented_images', num_augmentations=3)

注意:数据增强需要在标注之前进行,或者使用支持自动调整标注坐标的工具。

预处理检查清单

数据清洗

  • 删除模糊图片
  • 删除失焦图片
  • 删除重复图片
  • 删除无关图片

格式统一

  • 统一为JPG或PNG格式
  • 转换为RGB模式
  • 检查文件完整性

尺寸统一

  • 调整到目标尺寸(如640×640)
  • 保持宽高比(推荐)
  • 检查图片质量

数据增强(可选):

  • 确定增强方法
  • 应用增强
  • 检查增强效果

数据统计

  • 统计最终图片数量
  • 检查各类别分布
  • 验证数据质量

🎨 步骤 2:数据标注

2.1 选择标注工具

工具选择建议

不同工具有不同的特点:

  • 免费工具:适合预算有限的用户,功能可能相对简单
  • 付费工具:功能通常更全面,适合有预算的企业用户
  • 选择原则:根据项目需求、预算和技术能力选择

TjMakeBot的特点

  • ✅ 免费(基础功能)
  • ✅ AI 聊天式标注,显著提升效率
  • ✅ 支持批量处理
  • ✅ 在线即用,无需安装
  • ✅ 支持视频转帧

2.2 创建类别标签

在 TjMakeBot 中创建你的类别:

类别列表示例:
0: car (汽车)
1: person (行人)
2: bicycle (自行车)
3: motorcycle (摩托车)
4: bus (公交车)

命名规范

  • 使用英文小写
  • 避免空格和特殊字符
  • 类别名称要清晰明确

2.3 开始标注:两种方法详解

方法 1:AI 聊天式标注(强烈推荐)

适用场景

  • 批量标注(> 100张)
  • 标准场景(常见对象)
  • 快速原型开发
  • 预算有限的项目

完整流程

步骤1:上传图片(1分钟)

  • 批量上传所有图片
  • 建议先上传10-20张测试

步骤2:打开AI助手(5秒)

  • 点击"AI助手"按钮
  • 聊天面板打开

步骤3:输入指令(10秒)

基础指令:
"请标注所有汽车和行人"

高级指令:
"标注所有车辆,但排除摩托车"
"标注图片中心区域的所有目标"
"标注所有大于100像素的汽车"

步骤4:AI自动标注(自动)

  • AI理解指令
  • 自动识别目标
  • 生成标注结果

步骤5:检查与微调(5-10分钟/100张)

  • 快速浏览标注结果
  • 修正明显错误
  • 补充遗漏标注

步骤6:应用到全部(1秒)

  • 确认效果满意
  • 一键应用到所有图片

优势

  • ⚡ 速度快:1000张图片2-3小时完成
  • 🎯 准确率高:AI准确率通常>90%
  • 💰 成本低:免费工具,成本几乎为零
  • 📈 效率高:批量处理,大幅提升效率

真实案例

某学生项目需要标注2000张图片。使用AI聊天式标注,2天完成标注,准确率达到95%。如果用传统方式,需要2周时间。

方法 2:手动标注(适合复杂场景)

适用场景

  • 复杂场景(AI难以识别)
  • 特殊对象(AI未训练过的类别)
  • 高精度要求(需要像素级精度)
  • 小规模项目(< 100张)

完整流程

步骤1:选择图片(5秒)

  • 点击图片,打开标注界面

步骤2:选择类别(3秒)

  • 从类别列表中选择
  • 或创建新类别

步骤3:绘制边界框(10-30秒)

  • 鼠标拖拽绘制矩形框
  • 从左上角拖到右下角
  • 或使用快捷键

步骤4:调整位置和大小(10-20秒)

  • 拖拽边界框移动位置
  • 拖拽角点调整大小
  • 使用方向键微调

步骤5:保存标注(2秒)

  • 自动保存
  • 或手动保存

手动标注技巧

技巧1:使用快捷键

  • W:切换工具
  • Delete:删除选中标注
  • 方向键:微调位置
  • Ctrl+Z:撤销

技巧2:精确调整

  • 使用缩放功能放大图片
  • 使用十字标线精确定位
  • 多次微调达到最佳位置

技巧3:批量操作

  • 复制标注到下一张
  • 批量删除错误标注
  • 批量修改类别

优势

  • 🎯 精度高:可以精确到像素级
  • 🔧 灵活:可以处理任何场景
  • 📝 可控:完全控制标注过程

劣势

  • ⏱️ 速度慢:每张需要2-5分钟
  • 💰 成本高:需要大量人力
  • 😴 易疲劳:长时间标注容易出错

建议:结合使用AI辅助和手动标注,AI处理标准场景,手动处理复杂场景。

2.4 标注质量检查:确保数据质量

为什么质量检查如此重要?

一个真实案例:

某项目标注了5000张图片,训练模型后发现准确率只有70%。经过检查,发现标注数据中存在15%的错误。重新标注后,模型准确率提升到92%。

质量检查清单

1. 完整性检查(最重要)

  • ✅ 所有目标对象都已标注
  • ✅ 没有遗漏的对象
  • ✅ 部分遮挡的对象也已标注

检查方法

  • 逐张浏览,查找遗漏
  • 使用AI辅助检查(AI可以识别遗漏)
  • 抽样检查(每10张检查1张)

2. 准确性检查

  • ✅ 边界框精确覆盖目标
  • ✅ 边界框不包含过多背景(< 10%)
  • ✅ 边界框不遗漏目标部分

检查方法

  • 检查边界框是否紧贴目标边缘
  • 检查是否有明显偏差
  • 使用IoU指标评估

3. 类别准确性

  • ✅ 类别标签正确
  • ✅ 没有类别混淆
  • ✅ 边界情况处理正确

检查方法

  • 检查每个标注框的类别
  • 特别关注容易混淆的类别
  • 统一边界情况的处理

4. 一致性检查

  • ✅ 没有重复标注
  • ✅ 标注标准统一
  • ✅ 不同标注员标准一致

检查方法

  • 检查是否有重叠的标注框
  • 对比不同标注员的标注
  • 统计标注差异

质量指标标准

指标 最低标准 推荐标准 优秀标准
标注完整率 > 90% > 95% > 98%
边界框准确率 > 85% > 90% > 95%
类别准确率 > 95% > 98% > 99%
标注一致性 > 85% > 90% > 95%

质量检查工具

TjMakeBot内置质量检查

  • 自动检测遗漏标注
  • 自动检测重复标注
  • 自动检测边界框偏差
  • 生成质量报告

使用步骤

  1. 完成标注后,点击"质量检查"
  2. 系统自动分析标注质量
  3. 生成质量报告
  4. 根据报告修正问题

质量改进流程

第一轮检查(标注完成后):

  • 快速浏览所有图片
  • 发现明显错误
  • 修正错误标注

第二轮检查(修正后):

  • 抽样检查(20-30%)
  • 详细检查边界框
  • 检查类别准确性

第三轮检查(最终确认):

  • 专家审核
  • 性能测试
  • 最终确认

质量检查时间分配

  • 标注时间:70%
  • 质量检查:20%
  • 修正时间:10%

记住:质量检查的时间投入是值得的,可以避免后续的返工成本。

📦 步骤 3:数据格式转换

数据格式转换是将标注结果转换为YOLO训练所需格式的关键步骤。

3.1 导出 YOLO 格式

使用 TjMakeBot 导出

操作步骤

  1. 选择标注数据

    • 在TjMakeBot中打开标注项目
    • 选择所有已标注的图片
    • 或选择特定类别的图片
  2. 导出设置

    • 点击"导出"按钮
    • 选择"YOLO 格式"
    • 选择导出选项:
      • ✅ 包含图片
      • ✅ 包含标注文件
      • ✅ 保持目录结构
  3. 下载文件

    • 等待导出完成
    • 下载ZIP文件
    • 解压到本地目录

导出结果结构

dataset/
├── images/
│   ├── image001.jpg
│   ├── image002.jpg
│   └── ...
└── labels/
    ├── image001.txt
    ├── image002.txt
    └── ...

手动转换(从其他格式)

从VOC格式转换

import xml.etree.ElementTree as ET
import os

def voc_to_yolo(voc_xml_path, yolo_txt_path, img_width, img_height, class_mapping):
    """
    将VOC格式转换为YOLO格式
    """
    tree = ET.parse(voc_xml_path)
    root = tree.getroot()
    
    with open(yolo_txt_path, 'w') as f:
        for obj in root.findall('object'):
            # 获取类别
            class_name = obj.find('name').text
            class_id = class_mapping[class_name]
            
            # 获取边界框(VOC格式:xmin, ymin, xmax, ymax)
            bbox = obj.find('bndbox')
            xmin = float(bbox.find('xmin').text)
            ymin = float(bbox.find('ymin').text)
            xmax = float(bbox.find('xmax').text)
            ymax = float(bbox.find('ymax').text)
            
            # 转换为YOLO格式
            center_x = ((xmin + xmax) / 2) / img_width
            center_y = ((ymin + ymax) / 2) / img_height
            width = (xmax - xmin) / img_width
            height = (ymax - ymin) / img_height
            
            # 写入文件
            f.write(f"{class_id} {center_x} {center_y} {width} {height}\n")

# 使用
class_mapping = {'car': 0, 'person': 1, 'bicycle': 2}
voc_to_yolo('annotations/image001.xml', 'labels/image001.txt', 1920, 1080, class_mapping)

从COCO格式转换

import json
from PIL import Image

def coco_to_yolo(coco_json_path, output_dir, class_mapping):
    """
    将COCO格式转换为YOLO格式
    """
    with open(coco_json_path, 'r') as f:
        coco_data = json.load(f)
    
    # 创建输出目录
    os.makedirs(f'{output_dir}/labels', exist_ok=True)
    
    # 建立图片ID到文件名的映射
    img_id_to_info = {img['id']: img for img in coco_data['images']}
    
    # 按图片ID分组标注
    annotations_by_img = {}
    for ann in coco_data['annotations']:
        img_id = ann['image_id']
        if img_id not in annotations_by_img:
            annotations_by_img[img_id] = []
        annotations_by_img[img_id].append(ann)
    
    # 转换每个图片的标注
    for img_id, anns in annotations_by_img.items():
        img_info = img_id_to_info[img_id]
        img_width = img_info['width']
        img_height = img_info['height']
        
        # 创建YOLO格式文件
        label_file = f"{output_dir}/labels/{img_info['file_name'].replace('.jpg', '.txt')}"
        with open(label_file, 'w') as f:
            for ann in anns:
                category_id = ann['category_id']
                class_name = next(cat['name'] for cat in coco_data['categories'] if cat['id'] == category_id)
                class_id = class_mapping.get(class_name, -1)
                
                if class_id == -1:
                    continue  # 跳过未映射的类别
                
                # COCO格式:x, y, width, height(绝对坐标)
                bbox = ann['bbox']
                x, y, w, h = bbox
                
                # 转换为YOLO格式(归一化)
                center_x = (x + w / 2) / img_width
                center_y = (y + h / 2) / img_height
                norm_w = w / img_width
                norm_h = h / img_height
                
                f.write(f"{class_id} {center_x} {center_y} {norm_w} {norm_h}\n")

# 使用
class_mapping = {'car': 0, 'person': 1, 'bicycle': 2}
coco_to_yolo('annotations/instances_train2017.json', './yolo_dataset', class_mapping)

3.2 验证标注文件

验证标注文件是确保数据质量的关键步骤,可以避免训练时的错误。

验证脚本

完整验证脚本

import os
from PIL import Image

def validate_yolo_dataset(dataset_dir):
    """
    验证YOLO数据集
    """
    images_dir = os.path.join(dataset_dir, 'images')
    labels_dir = os.path.join(dataset_dir, 'labels')
    
    errors = []
    warnings = []
    
    # 获取所有图片文件
    image_files = [f for f in os.listdir(images_dir) if f.endswith(('.jpg', '.png'))]
    
    for img_file in image_files:
        img_path = os.path.join(images_dir, img_file)
        label_file = os.path.splitext(img_file)[0] + '.txt'
        label_path = os.path.join(labels_dir, label_file)
        
        # 检查1:标注文件是否存在
        if not os.path.exists(label_path):
            errors.append(f"缺失标注文件: {label_file}")
            continue
        
        # 检查2:图片是否可以打开
        try:
            img = Image.open(img_path)
            img_width, img_height = img.size
        except Exception as e:
            errors.append(f"无法打开图片: {img_file} - {str(e)}")
            continue
        
        # 检查3:读取标注文件
        try:
            with open(label_path, 'r') as f:
                lines = f.readlines()
        except Exception as e:
            errors.append(f"无法读取标注文件: {label_file} - {str(e)}")
            continue
        
        # 检查4:验证每行格式
        for line_num, line in enumerate(lines, 1):
            line = line.strip()
            if not line:
                continue
            
            parts = line.split()
            
            # 检查格式:应该有5个数字
            if len(parts) != 5:
                errors.append(f"{label_file}:{line_num} - 格式错误,应该有5个数字,实际有{len(parts)}个")
                continue
            
            try:
                class_id = int(parts[0])
                center_x = float(parts[1])
                center_y = float(parts[2])
                width = float(parts[3])
                height = float(parts[4])
            except ValueError as e:
                errors.append(f"{label_file}:{line_num} - 无法解析数字: {str(e)}")
                continue
            
            # 检查5:类别ID是否有效
            if class_id < 0:
                errors.append(f"{label_file}:{line_num} - 类别ID不能为负数: {class_id}")
            
            # 检查6:坐标是否在0-1范围内
            if not (0 <= center_x <= 1):
                errors.append(f"{label_file}:{line_num} - center_x超出范围: {center_x}")
            if not (0 <= center_y <= 1):
                errors.append(f"{label_file}:{line_num} - center_y超出范围: {center_y}")
            if not (0 < width <= 1):
                errors.append(f"{label_file}:{line_num} - width超出范围: {width}")
            if not (0 < height <= 1):
                errors.append(f"{label_file}:{line_num} - height超出范围: {height}")
            
            # 检查7:边界框是否超出图片范围
            x_min = center_x - width / 2
            x_max = center_x + width / 2
            y_min = center_y - height / 2
            y_max = center_y + height / 2
            
            if x_min < 0 or x_max > 1 or y_min < 0 or y_max > 1:
                warnings.append(f"{label_file}:{line_num} - 边界框超出图片范围")
            
            # 检查8:边界框是否太小
            if width < 0.01 or height < 0.01:
                warnings.append(f"{label_file}:{line_num} - 边界框太小(可能标注错误)")
            
            # 检查9:边界框是否太大
            if width > 0.95 or height > 0.95:
                warnings.append(f"{label_file}:{line_num} - 边界框太大(可能标注错误)")
    
    # 输出结果
    print("=" * 50)
    print("验证结果")
    print("=" * 50)
    
    if errors:
        print(f"\n❌ 发现 {len(errors)} 个错误:")
        for error in errors[:10]:  # 只显示前10个
            print(f"  - {error}")
        if len(errors) > 10:
            print(f"  ... 还有 {len(errors) - 10} 个错误")
    else:
        print("\n✅ 未发现错误")
    
    if warnings:
        print(f"\n⚠️  发现 {len(warnings)} 个警告:")
        for warning in warnings[:10]:  # 只显示前10个
            print(f"  - {warning}")
        if len(warnings) > 10:
            print(f"  ... 还有 {len(warnings) - 10} 个警告")
    else:
        print("\n✅ 未发现警告")
    
    return len(errors) == 0

# 使用
is_valid = validate_yolo_dataset('./dataset')
if is_valid:
    print("\n✅ 数据集验证通过,可以开始训练")
else:
    print("\n❌ 数据集验证失败,请修复错误后再训练")

验证检查清单

文件完整性

  • 每个图片都有对应的标注文件
  • 每个标注文件都有对应的图片
  • 文件名匹配(除了扩展名)

格式正确性

  • 标注文件每行有5个数字
  • 所有数字都是有效的浮点数
  • 类别ID是整数

坐标有效性

  • 所有坐标值在0-1范围内
  • 边界框不超出图片范围
  • 边界框大小合理(不太小也不太大)

数据一致性

  • 类别ID连续(0, 1, 2, ...)
  • 没有重复标注
  • 标注与图片内容匹配

3.3 创建数据集配置文件

数据集配置文件是YOLO训练必需的,定义了数据集的路径、类别等信息。

YOLOv8 配置文件

标准格式dataset.yaml):

# 数据集路径(相对于此文件或绝对路径)
path: /path/to/dataset  # 数据集根目录

# 训练集和验证集路径(相对于path)
train: images/train  # 训练集图片目录
val: images/val      # 验证集图片目录
test: images/test    # 测试集图片目录(可选)

# 类别数量
nc: 5

# 类别名称(必须与类别ID对应)
names:
  0: car
  1: person
  2: bicycle
  3: motorcycle
  4: bus

YOLOv5 配置文件

标准格式dataset.yaml):

# 数据集路径
train: /path/to/dataset/images/train
val: /path/to/dataset/images/val
test: /path/to/dataset/images/test  # 可选

# 类别数量
nc: 5

# 类别名称
names: ['car', 'person', 'bicycle', 'motorcycle', 'bus']

配置文件生成脚本

自动生成脚本

import os
import yaml

def create_dataset_yaml(dataset_dir, class_names, output_file='dataset.yaml', yolo_version='v8'):
    """
    自动生成数据集配置文件
    """
    # 检查目录结构
    images_dir = os.path.join(dataset_dir, 'images')
    labels_dir = os.path.join(dataset_dir, 'labels')
    
    # 检查是否有train/val/test子目录
    has_splits = os.path.exists(os.path.join(images_dir, 'train'))
    
    if yolo_version == 'v8':
        if has_splits:
            config = {
                'path': os.path.abspath(dataset_dir),
                'train': 'images/train',
                'val': 'images/val',
                'nc': len(class_names),
                'names': {i: name for i, name in enumerate(class_names)}
            }
            
            # 如果有测试集
            if os.path.exists(os.path.join(images_dir, 'test')):
                config['test'] = 'images/test'
        else:
            # 如果没有划分,使用images目录
            config = {
                'path': os.path.abspath(dataset_dir),
                'train': 'images',
                'val': 'images',  # 注意:实际使用时需要划分
                'nc': len(class_names),
                'names': {i: name for i, name in enumerate(class_names)}
            }
    else:  # YOLOv5
        if has_splits:
            config = {
                'train': os.path.join(os.path.abspath(dataset_dir), 'images', 'train'),
                'val': os.path.join(os.path.abspath(dataset_dir), 'images', 'val'),
                'nc': len(class_names),
                'names': class_names
            }
            
            if os.path.exists(os.path.join(images_dir, 'test')):
                config['test'] = os.path.join(os.path.abspath(dataset_dir), 'images', 'test')
        else:
            config = {
                'train': os.path.join(os.path.abspath(dataset_dir), 'images'),
                'val': os.path.join(os.path.abspath(dataset_dir), 'images'),
                'nc': len(class_names),
                'names': class_names
            }
    
    # 保存配置文件
    with open(output_file, 'w', encoding='utf-8') as f:
        yaml.dump(config, f, allow_unicode=True, default_flow_style=False)
    
    print(f"✅ 配置文件已生成: {output_file}")
    print("\n配置文件内容:")
    print("=" * 50)
    with open(output_file, 'r', encoding='utf-8') as f:
        print(f.read())
    print("=" * 50)

# 使用示例
class_names = ['car', 'person', 'bicycle', 'motorcycle', 'bus']
create_dataset_yaml('./dataset', class_names, 'dataset.yaml', yolo_version='v8')

配置文件验证

验证脚本

import yaml
import os

def validate_dataset_yaml(yaml_file, dataset_dir):
    """
    验证数据集配置文件
    """
    with open(yaml_file, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    
    errors = []
    
    # 检查必需的字段
    required_fields = ['nc', 'names']
    for field in required_fields:
        if field not in config:
            errors.append(f"缺少必需字段: {field}")
    
    # 检查类别数量
    if 'nc' in config and 'names' in config:
        if isinstance(config['names'], dict):
            num_names = len(config['names'])
        else:
            num_names = len(config['names'])
        
        if config['nc'] != num_names:
            errors.append(f"类别数量不匹配: nc={config['nc']}, names数量={num_names}")
    
    # 检查路径
    if 'path' in config:
        path = config['path']
        if not os.path.isabs(path):
            path = os.path.join(os.path.dirname(yaml_file), path)
        
        if not os.path.exists(path):
            errors.append(f"数据集路径不存在: {path}")
    
    # 检查训练集和验证集路径
    for split in ['train', 'val']:
        if split in config:
            split_path = config[split]
            if 'path' in config:
                full_path = os.path.join(config['path'], split_path)
            else:
                full_path = split_path
            
            if not os.path.exists(full_path):
                errors.append(f"{split}路径不存在: {full_path}")
    
    if errors:
        print("❌ 配置文件验证失败:")
        for error in errors:
            print(f"  - {error}")
        return False
    else:
        print("✅ 配置文件验证通过")
        return True

# 使用
validate_dataset_yaml('dataset.yaml', './dataset')

配置文件检查清单

基本配置

  • 类别数量(nc)正确
  • 类别名称(names)完整
  • 类别ID从0开始连续

路径配置

  • 数据集路径(path)正确
  • 训练集路径(train)存在
  • 验证集路径(val)存在
  • 测试集路径(test)存在(如果使用)

格式正确

  • YAML格式正确
  • 编码为UTF-8
  • 缩进正确(使用空格,不是Tab)

🔄 步骤 4:数据集划分

数据集划分是训练前的关键步骤,合理的划分可以确保模型评估的准确性。

4.1 划分策略

划分比例选择

标准划分比例

数据集规模 训练集 验证集 测试集 说明
小数据集(< 1000张) 70% 15% 15% 确保有足够训练数据
中等数据集(1000-10000张) 75% 12.5% 12.5% 平衡训练和评估
大数据集(> 10000张) 80% 10% 10% 训练数据充足,验证集足够

为什么需要三个集合?

  1. 训练集(Train)

    • 用于模型训练
    • 模型学习数据特征
    • 通常占70-80%
  2. 验证集(Validation)

    • 用于调整超参数
    • 监控训练过程
    • 防止过拟合
    • 通常占10-15%
  3. 测试集(Test)

    • 用于最终评估
    • 不参与训练和调参
    • 反映模型真实性能
    • 通常占10-15%

划分原则

1. 随机划分(基础方法)

适用场景

  • 数据场景相似
  • 无时间序列关系
  • 无场景相关性

方法

  • 随机打乱所有数据
  • 按比例划分
  • 保证各类别分布一致

2. 分层划分(推荐)

适用场景

  • 类别不平衡
  • 需要保证各类别比例一致

方法

  • 按类别分别划分
  • 每个类别都按相同比例划分
  • 保证训练集、验证集、测试集的类别分布一致

3. 场景划分(高级方法)

适用场景

  • 不同场景的数据
  • 需要测试泛化能力
  • 避免数据泄露

方法

  • 按场景分组
  • 同一场景的数据在同一集合
  • 避免训练集和测试集场景重叠

真实案例

某自动驾驶项目有5个城市的道路数据。如果随机划分,可能导致训练集和测试集都包含同一城市的数据,测试结果会过于乐观。正确的做法是按城市划分,训练集用3个城市,验证集用1个城市,测试集用1个城市。

类别平衡检查

检查脚本

import os
from collections import Counter

def check_class_balance(dataset_dir, splits=['train', 'val', 'test']):
    """
    检查各类别的数据分布
    """
    results = {}
    
    for split in splits:
        labels_dir = os.path.join(dataset_dir, 'labels', split)
        if not os.path.exists(labels_dir):
            continue
        
        class_counts = Counter()
        total_objects = 0
        
        for label_file in os.listdir(labels_dir):
            if label_file.endswith('.txt'):
                with open(os.path.join(labels_dir, label_file), 'r') as f:
                    for line in f:
                        if line.strip():
                            class_id = int(line.split()[0])
                            class_counts[class_id] += 1
                            total_objects += 1
        
        results[split] = {
            'class_counts': dict(class_counts),
            'total_objects': total_objects,
            'num_images': len([f for f in os.listdir(labels_dir) if f.endswith('.txt')])
        }
    
    # 打印结果
    print("=" * 60)
    print("类别分布统计")
    print("=" * 60)
    
    for split, data in results.items():
        print(f"\n{split.upper()}集:")
        print(f"  图片数量: {data['num_images']}")
        print(f"  目标总数: {data['total_objects']}")
        print(f"  类别分布:")
        
        for class_id in sorted(data['class_counts'].keys()):
            count = data['class_counts'][class_id]
            percentage = count / data['total_objects'] * 100
            print(f"    类别{class_id}: {count} ({percentage:.1f}%)")
    
    # 检查平衡性
    print("\n" + "=" * 60)
    print("平衡性检查")
    print("=" * 60)
    
    if 'train' in results:
        train_counts = results['train']['class_counts']
        max_count = max(train_counts.values())
        min_count = min(train_counts.values())
        imbalance_ratio = max_count / min_count if min_count > 0 else float('inf')
        
        print(f"训练集类别不平衡比例: {imbalance_ratio:.2f}")
        if imbalance_ratio > 10:
            print("⚠️  警告:类别严重不平衡,建议平衡数据")
        elif imbalance_ratio > 5:
            print("⚠️  注意:类别存在不平衡,建议考虑平衡")
        else:
            print("✅ 类别分布相对平衡")

# 使用
check_class_balance('./dataset')

4.2 使用脚本划分

基础划分脚本

简单随机划分

import os
import shutil
import random

def split_dataset_simple(source_dir, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15, seed=42):
    """
    简单随机划分数据集
    """
    # 设置随机种子,确保可重复
    random.seed(seed)
    
    images_dir = os.path.join(source_dir, 'images')
    labels_dir = os.path.join(source_dir, 'labels')
    
    # 获取所有图片
    images = [f for f in os.listdir(images_dir) if f.endswith(('.jpg', '.png'))]
    random.shuffle(images)
    
    # 计算划分点
    total = len(images)
    train_end = int(total * train_ratio)
    val_end = train_end + int(total * val_ratio)
    
    # 划分
    train_images = images[:train_end]
    val_images = images[train_end:val_end]
    test_images = images[val_end:]
    
    print(f"总图片数: {total}")
    print(f"训练集: {len(train_images)} ({len(train_images)/total*100:.1f}%)")
    print(f"验证集: {len(val_images)} ({len(val_images)/total*100:.1f}%)")
    print(f"测试集: {len(test_images)} ({len(test_images)/total*100:.1f}%)")
    
    # 复制文件
    for split, img_list in [('train', train_images), 
                            ('val', val_images), 
                            ('test', test_images)]:
        split_images_dir = os.path.join(source_dir, 'images', split)
        split_labels_dir = os.path.join(source_dir, 'labels', split)
        
        os.makedirs(split_images_dir, exist_ok=True)
        os.makedirs(split_labels_dir, exist_ok=True)
        
        for img in img_list:
            # 复制图片
            src_img = os.path.join(images_dir, img)
            dst_img = os.path.join(split_images_dir, img)
            shutil.copy(src_img, dst_img)
            
            # 复制标注
            label_name = os.path.splitext(img)[0] + '.txt'
            src_label = os.path.join(labels_dir, label_name)
            dst_label = os.path.join(split_labels_dir, label_name)
            
            if os.path.exists(src_label):
                shutil.copy(src_label, dst_label)
            else:
                print(f"⚠️  警告:标注文件不存在: {label_name}")
    
    print("\n✅ 数据集划分完成")

# 使用
split_dataset_simple('./dataset', train_ratio=0.7, val_ratio=0.15, test_ratio=0.15)

分层划分脚本(推荐)

按类别分层划分

import os
import shutil
import random
from collections import defaultdict

def split_dataset_stratified(source_dir, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15, seed=42):
    """
    分层划分数据集(按类别)
    """
    random.seed(seed)
    
    images_dir = os.path.join(source_dir, 'images')
    labels_dir = os.path.join(source_dir, 'labels')
    
    # 按类别分组图片
    images_by_class = defaultdict(list)
    
    for img_file in os.listdir(images_dir):
        if img_file.endswith(('.jpg', '.png')):
            label_file = os.path.splitext(img_file)[0] + '.txt'
            label_path = os.path.join(labels_dir, label_file)
            
            if os.path.exists(label_path):
                # 读取标注文件,获取类别
                with open(label_path, 'r') as f:
                    classes = set()
                    for line in f:
                        if line.strip():
                            class_id = int(line.split()[0])
                            classes.add(class_id)
                    
                    # 如果图片包含多个类别,使用主要类别
                    if classes:
                        main_class = max(classes, key=lambda c: sum(1 for line in open(label_path) if line.strip() and int(line.split()[0]) == c))
                        images_by_class[main_class].append(img_file)
    
    # 对每个类别分别划分
    train_images = []
    val_images = []
    test_images = []
    
    for class_id, images in images_by_class.items():
        random.shuffle(images)
        
        total = len(images)
        train_end = int(total * train_ratio)
        val_end = train_end + int(total * val_ratio)
        
        train_images.extend(images[:train_end])
        val_images.extend(images[train_end:val_end])
        test_images.extend(images[val_end:])
        
        print(f"类别{class_id}: 总数={total}, 训练={train_end}, 验证={val_end-train_end}, 测试={total-val_end}")
    
    # 打乱最终列表
    random.shuffle(train_images)
    random.shuffle(val_images)
    random.shuffle(test_images)
    
    print(f"\n总划分结果:")
    print(f"训练集: {len(train_images)}")
    print(f"验证集: {len(val_images)}")
    print(f"测试集: {len(test_images)}")
    
    # 复制文件
    for split, img_list in [('train', train_images), 
                            ('val', val_images), 
                            ('test', test_images)]:
        split_images_dir = os.path.join(source_dir, 'images', split)
        split_labels_dir = os.path.join(source_dir, 'labels', split)
        
        os.makedirs(split_images_dir, exist_ok=True)
        os.makedirs(split_labels_dir, exist_ok=True)
        
        for img in img_list:
            # 复制图片
            shutil.copy(os.path.join(images_dir, img), 
                       os.path.join(split_images_dir, img))
            
            # 复制标注
            label_name = os.path.splitext(img)[0] + '.txt'
            src_label = os.path.join(labels_dir, label_name)
            dst_label = os.path.join(split_labels_dir, label_name)
            
            if os.path.exists(src_label):
                shutil.copy(src_label, dst_label)
    
    print("\n✅ 分层划分完成")

# 使用
split_dataset_stratified('./dataset', train_ratio=0.7, val_ratio=0.15, test_ratio=0.15)

划分后验证

验证脚本

def verify_split(dataset_dir):
    """
    验证数据集划分结果
    """
    splits = ['train', 'val', 'test']
    
    for split in splits:
        images_dir = os.path.join(dataset_dir, 'images', split)
        labels_dir = os.path.join(dataset_dir, 'labels', split)
        
        if not os.path.exists(images_dir):
            print(f"⚠️  {split}集图片目录不存在")
            continue
        
        images = [f for f in os.listdir(images_dir) if f.endswith(('.jpg', '.png'))]
        labels = [f for f in os.listdir(labels_dir) if f.endswith('.txt')]
        
        # 检查图片和标注是否匹配
        missing_labels = []
        for img in images:
            label_name = os.path.splitext(img)[0] + '.txt'
            if label_name not in labels:
                missing_labels.append(label_name)
        
        if missing_labels:
            print(f"⚠️  {split}集有{len(missing_labels)}个图片缺少标注文件")
        else:
            print(f"✅ {split}集: {len(images)}张图片, {len(labels)}个标注文件, 全部匹配")

# 使用
verify_split('./dataset')

划分检查清单

划分前准备

  • 所有图片已标注
  • 标注文件已验证
  • 数据已清洗

划分过程

  • 使用随机种子确保可重复
  • 按类别分层划分(推荐)
  • 保持类别分布一致

划分后验证

  • 图片和标注文件匹配
  • 各类别分布检查
  • 划分比例符合预期

目录结构

  • 创建train/val/test子目录
  • 图片和标注文件正确复制
  • 目录结构清晰

🚀 步骤 5:模型训练

模型训练是将标注数据转化为可用模型的过程,需要合理配置参数和监控训练过程。

5.1 安装 YOLO 环境

YOLOv8 安装(推荐)

为什么选择YOLOv8?

  • ✅ 最新版本,性能最优
  • ✅ 安装简单,一行命令
  • ✅ API友好,易于使用
  • ✅ 文档完善,社区活跃

安装步骤

1. 基础安装

# 安装ultralytics(包含YOLOv8)
pip install ultralytics

# 验证安装
python -c "from ultralytics import YOLO; print('YOLOv8安装成功')"

2. GPU支持(可选,但强烈推荐)

# 检查CUDA是否可用
python -c "import torch; print(f'CUDA可用: {torch.cuda.is_available()}')"

# 如果CUDA不可用,安装CPU版本
# pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu

3. 依赖检查

# 检查关键依赖
pip list | grep -E "torch|ultralytics|opencv|pillow"

环境要求

  • Python 3.8+
  • PyTorch 1.8+
  • CUDA 11.0+(GPU训练,可选)

YOLOv5 安装(备选)

安装步骤

# 克隆仓库
git clone https://github.com/ultralytics/yolov5
cd yolov5

# 安装依赖
pip install -r requirements.txt

# 验证安装
python detect.py --help

依赖要求

  • Python 3.7+
  • PyTorch 1.7+
  • 其他依赖见requirements.txt

5.2 训练配置

YOLOv8 训练配置详解

完整训练脚本

from ultralytics import YOLO
import torch

# 检查设备
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"使用设备: {device}")

# 加载预训练模型
# 模型选择:
# - yolov8n.pt: nano(最小,最快)
# - yolov8s.pt: small(小,快速)
# - yolov8m.pt: medium(中等,平衡)
# - yolov8l.pt: large(大,高精度)
# - yolov8x.pt: xlarge(最大,最高精度)
model = YOLO('yolov8n.pt')  # 根据需求选择

# 训练配置
results = model.train(
    # 数据集配置
    data='dataset.yaml',      # 数据集配置文件路径
    
    # 训练参数
    epochs=100,               # 训练轮数(建议:100-300)
    imgsz=640,                # 输入图片尺寸(640/416/1280)
    batch=16,                 # 批次大小(根据GPU内存调整)
    device=device,            # 设备('cuda'/'cpu'/'0,1'多GPU)
    
    # 优化器参数
    lr0=0.01,                 # 初始学习率
    lrf=0.01,                 # 最终学习率(lr0 * lrf)
    momentum=0.937,           # 动量
    weight_decay=0.0005,     # 权重衰减
    
    # 数据增强
    hsv_h=0.015,             # 色调增强
    hsv_s=0.7,               # 饱和度增强
    hsv_v=0.4,               # 明度增强
    degrees=0.0,             # 旋转角度
    translate=0.1,           # 平移
    scale=0.5,               # 缩放
    flipud=0.0,             # 上下翻转概率
    fliplr=0.5,             # 左右翻转概率
    mosaic=1.0,             # Mosaic增强概率
    mixup=0.0,              # MixUp增强概率
    
    # 训练设置
    patience=50,             # 早停耐心值(验证集不提升的轮数)
    save=True,               # 保存检查点
    save_period=10,          # 每N轮保存一次
    val=True,                # 训练时验证
    plots=True,              # 生成训练曲线图
    
    # 项目设置
    project='runs/detect',    # 项目目录
    name='my_model',         # 实验名称
    exist_ok=True,           # 允许覆盖已存在的实验
    pretrained=True,         # 使用预训练权重
    optimizer='SGD',         # 优化器(SGD/Adam/AdamW)
    verbose=True,            # 详细输出
    seed=0,                  # 随机种子
    deterministic=True,      # 确定性训练
    single_cls=False,        # 单类别模式
    rect=False,              # 矩形训练
    cos_lr=False,            # 余弦学习率调度
    close_mosaic=10,         # 最后N轮关闭Mosaic
    resume=False,            # 恢复训练
    amp=True,                # 自动混合精度
    fraction=1.0,            # 使用数据集的比例
    profile=False,           # 性能分析
    freeze=None,             # 冻结层(如:freeze=10冻结前10层)
)

# 训练完成后
print("训练完成!")
print(f"最佳模型保存在: {results.save_dir}")

关键参数详解

1. 模型选择

模型 参数量 速度 精度 适用场景
yolov8n 3.2M 最快 较低 实时检测,边缘设备
yolov8s 11.2M 中等 平衡速度和精度
yolov8m 25.9M 中等 较高 生产环境(推荐)
yolov8l 43.7M 较慢 高精度要求
yolov8x 68.2M 最慢 最高 研究,最高精度

选择建议

  • 初学者:yolov8n(快速验证)
  • 生产环境:yolov8m(平衡)
  • 高精度要求:yolov8l或yolov8x

2. 批次大小(batch)

GPU内存与批次大小

GPU内存 推荐批次大小(640×640)
4GB 4-8
6GB 8-12
8GB 12-16
12GB 16-24
16GB+ 24-32

调整方法

  • 如果内存不足,减小batch或imgsz
  • 如果内存充足,增大batch可以提升训练稳定性

3. 学习率(lr0)

学习率选择

  • 默认:0.01(SGD优化器)
  • 小数据集:0.001-0.005
  • 大数据集:0.01-0.02
  • 微调:0.0001-0.001

学习率调度

  • 余弦退火:cos_lr=True,学习率按余弦曲线下降
  • 线性衰减:默认,学习率线性下降

4. 训练轮数(epochs)

轮数建议

  • 小数据集(< 1000张):200-300轮
  • 中等数据集(1000-10000张):100-200轮
  • 大数据集(> 10000张):50-100轮

早停机制

  • patience=50:验证集性能50轮不提升则停止
  • 避免过拟合,节省训练时间

YOLOv5 训练配置

训练脚本

import torch
from pathlib import Path

# 设置路径
data_yaml = 'dataset.yaml'
weights = 'yolov5s.pt'  # 预训练权重
epochs = 100
batch_size = 16
img_size = 640
device = '0' if torch.cuda.is_available() else 'cpu'

# 训练命令(通过命令行)
# python train.py --data dataset.yaml --weights yolov5s.pt --epochs 100 --batch-size 16 --img 640 --device 0

5.3 训练监控

关键指标说明

1. mAP(Mean Average Precision)

mAP50

  • IoU阈值=0.5时的平均精度
  • 衡量模型整体性能
  • 目标:> 0.5(50%)

mAP50-95

  • IoU阈值从0.5到0.95的平均精度
  • 更严格的评估标准
  • 目标:> 0.3(30%)

2. Precision(精确率)

  • 预测为正例中真正为正例的比例
  • 衡量误检率
  • 目标:> 0.8(80%)

3. Recall(召回率)

  • 真正例中被正确预测的比例
  • 衡量漏检率
  • 目标:> 0.8(80%)

4. Loss(损失)

训练损失(train/box_loss)

  • 训练集上的边界框损失
  • 应该持续下降

验证损失(val/box_loss)

  • 验证集上的边界框损失
  • 应该下降,如果上升说明过拟合

训练过程监控

实时监控

# 训练过程中会自动生成:
# - 训练曲线图(results.png)
# - 混淆矩阵(confusion_matrix.png)
# - 验证结果(val_batch*.jpg)
# - 训练日志(results.csv)

查看训练日志

import pandas as pd
import matplotlib.pyplot as plt

# 读取训练日志
df = pd.read_csv('runs/detect/my_model/results.csv')

# 绘制训练曲线
plt.figure(figsize=(12, 4))

plt.subplot(1, 3, 1)
plt.plot(df['epoch'], df['train/box_loss'], label='Train Loss')
plt.plot(df['epoch'], df['val/box_loss'], label='Val Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.title('Loss Curve')

plt.subplot(1, 3, 2)
plt.plot(df['epoch'], df['metrics/mAP50(B)'], label='mAP50')
plt.xlabel('Epoch')
plt.ylabel('mAP50')
plt.legend()
plt.title('mAP50 Curve')

plt.subplot(1, 3, 3)
plt.plot(df['epoch'], df['metrics/precision(B)'], label='Precision')
plt.plot(df['epoch'], df['metrics/recall(B)'], label='Recall')
plt.xlabel('Epoch')
plt.ylabel('Score')
plt.legend()
plt.title('Precision & Recall')

plt.tight_layout()
plt.savefig('training_curves.png')
plt.show()

训练技巧与最佳实践

1. 学习率调整策略

预热(Warm-up)

  • 前几个epoch使用较小的学习率
  • 帮助模型稳定训练
  • YOLOv8默认支持

学习率衰减

  • 使用余弦退火:cos_lr=True
  • 或线性衰减:默认

2. 数据增强策略

基础增强(默认开启):

  • 左右翻转:fliplr=0.5
  • 颜色增强:hsv_h/s/v
  • Mosaic:mosaic=1.0

高级增强(可选):

  • MixUp:mixup=0.15(小数据集)
  • 旋转:degrees=10(如果目标方向不重要)

3. 早停机制

设置

patience=50  # 验证集性能50轮不提升则停止

好处

  • 避免过拟合
  • 节省训练时间
  • 自动选择最佳模型

4. 模型检查点

自动保存

  • 每个epoch自动保存最佳模型
  • 保存在:runs/detect/my_model/weights/best.pt

手动保存

# 训练过程中可以随时保存
model.save('my_checkpoint.pt')

恢复训练

# 从检查点恢复训练
model = YOLO('runs/detect/my_model/weights/last.pt')
model.train(resume=True)

训练问题诊断

问题1:Loss不下降

可能原因

  • 学习率过大或过小
  • 数据质量差
  • 模型选择不当

解决方案

  • 调整学习率(尝试0.001-0.01)
  • 检查数据质量
  • 尝试更大的模型

问题2:过拟合(训练Loss下降,验证Loss上升)

可能原因

  • 数据量不足
  • 模型太大
  • 数据增强不足

解决方案

  • 增加数据量
  • 使用更小的模型
  • 增强数据增强
  • 使用dropout或正则化

问题3:训练很慢

可能原因

  • 使用CPU训练
  • 批次大小太小
  • 图片尺寸太大

解决方案

  • 使用GPU训练
  • 增大批次大小
  • 减小图片尺寸(如640→416)

训练检查清单

训练前准备

  • 数据集已划分(train/val/test)
  • 数据集配置文件(dataset.yaml)正确
  • 环境已安装(YOLOv8/YOLOv5)
  • GPU可用(如果使用GPU)

训练配置

  • 模型大小选择合适
  • 批次大小根据GPU内存设置
  • 学习率设置合理
  • 训练轮数足够

训练监控

  • 实时查看训练日志
  • 监控Loss曲线
  • 监控mAP曲线
  • 检查验证集性能

训练优化

  • 使用早停机制
  • 保存检查点
  • 调整超参数
  • 分析训练曲线

📊 步骤 6:模型评估与优化

模型评估是验证模型性能的关键步骤,优化是提升模型性能的持续过程。

6.1 评估模型

基础评估

YOLOv8 评估脚本

from ultralytics import YOLO

# 加载训练好的模型
model = YOLO('runs/detect/my_model/weights/best.pt')

# 在验证集上评估
metrics = model.val(data='dataset.yaml', split='val')

# 打印关键指标
print("=" * 50)
print("模型评估结果")
print("=" * 50)
print(f"mAP50: {metrics.box.map50:.4f}")
print(f"mAP50-95: {metrics.box.map:.4f}")
print(f"Precision: {metrics.box.mp:.4f}")
print(f"Recall: {metrics.box.mr:.4f}")
print("=" * 50)

# 在测试集上评估(如果存在)
if os.path.exists('dataset/images/test'):
    test_metrics = model.val(data='dataset.yaml', split='test')
    print("\n测试集评估结果:")
    print(f"mAP50: {test_metrics.box.map50:.4f}")
    print(f"mAP50-95: {test_metrics.box.map:.4f}")

详细评估指标

1. 按类别评估

# 获取每个类别的详细指标
for i, class_name in enumerate(model.names.values()):
    print(f"\n类别 {i} ({class_name}):")
    print(f"  Precision: {metrics.box.p[i]:.4f}")
    print(f"  Recall: {metrics.box.r[i]:.4f}")
    print(f"  mAP50: {metrics.box.ap50[i]:.4f}")
    print(f"  mAP50-95: {metrics.box.ap[i]:.4f}")

2. 混淆矩阵分析

# 查看混淆矩阵(自动生成在results目录)
# 文件位置:runs/detect/my_model/confusion_matrix.png
# 分析:
# - 对角线:正确分类
# - 非对角线:错误分类
# - 找出容易混淆的类别对

3. 可视化检测结果

# 在测试图片上可视化检测结果
results = model('dataset/images/test', save=True, conf=0.25)

# 查看检测结果
for result in results:
    # 获取检测框
    boxes = result.boxes
    # 获取类别
    classes = boxes.cls
    # 获取置信度
    confidences = boxes.conf
    
    print(f"检测到 {len(boxes)} 个目标")
    for i in range(len(boxes)):
        class_name = model.names[int(classes[i])]
        conf = confidences[i]
        print(f"  {class_name}: {conf:.2f}")

性能基准

性能评估标准

应用场景 mAP50目标 mAP50-95目标 说明
快速原型 > 0.5 > 0.3 验证想法
生产环境 > 0.7 > 0.5 实际应用
高精度应用 > 0.9 > 0.7 关键应用

真实案例

某工业质检项目:

  • 初始模型:mAP50=0.65,无法满足生产要求
  • 优化后:mAP50=0.85,达到生产标准
  • 优化方法:提升数据质量,增加数据量,调整超参数

6.2 常见问题与解决方案

问题诊断流程

1. 准确率低(mAP < 0.5)

诊断步骤

# 1. 检查数据质量
# - 标注是否准确
# - 数据是否平衡
# - 场景是否多样

# 2. 检查模型训练
# - Loss是否正常下降
# - 是否训练充分
# - 学习率是否合适

# 3. 检查模型选择
# - 模型是否太小
# - 是否需要更大的模型

解决方案

  • 提升数据质量:重新检查标注,修正错误
  • 增加数据量:收集更多高质量数据
  • 使用更大的模型:从yolov8n升级到yolov8m
  • 调整超参数:学习率、批次大小等

2. 过拟合(训练Loss低,验证Loss高)

诊断

# 检查训练曲线
# - train/box_loss持续下降
# - val/box_loss先下降后上升
# - 训练集mAP高,验证集mAP低

解决方案

  • 增加数据量:收集更多数据
  • 数据增强:启用更多数据增强
  • 使用更小的模型:减少模型复杂度
  • 正则化:增加dropout或权重衰减
  • 早停:使用早停机制

3. 漏检率高(Recall低)

诊断

# 检查各类别召回率
for i, class_name in enumerate(model.names.values()):
    recall = metrics.box.r[i]
    if recall < 0.7:
        print(f"⚠️  {class_name}召回率低: {recall:.2f}")

可能原因

  • 数据不平衡(某些类别样本少)
  • 小目标检测困难
  • 阈值设置过高

解决方案

  • 平衡数据:增加少数类样本
  • 降低置信度阈值:conf=0.15-0.25
  • 使用更高分辨率:imgsz=1280
  • 数据增强:针对小目标增强

4. 误检率高(Precision低)

诊断

# 检查各类别精确率
for i, class_name in enumerate(model.names.values()):
    precision = metrics.box.p[i]
    if precision < 0.7:
        print(f"⚠️  {class_name}精确率低: {precision:.2f}")

可能原因

  • 负样本不足
  • 类别相似度高
  • 阈值设置过低

解决方案

  • 增加负样本:添加不包含目标的图片
  • 提高置信度阈值:conf=0.3-0.5
  • 细化类别:区分相似类别
  • 后处理优化:调整NMS阈值

5. 训练很慢或不收敛

诊断

# 检查训练过程
# - Loss是否下降
# - 学习率是否合适
# - GPU利用率是否高

解决方案

  • 使用GPU:确保使用GPU训练
  • 调整批次大小:根据GPU内存调整
  • 调整学习率:尝试不同学习率
  • 检查数据:确保数据格式正确

问题解决对照表

问题 症状 可能原因 解决方案
准确率低 mAP < 0.5 数据质量差、数据量不足 提升数据质量、增加数据量
过拟合 训练集好,验证集差 数据量不足、模型太大 增加数据、使用小模型、数据增强
漏检率高 Recall < 0.7 数据不平衡、阈值高 平衡数据、降低阈值
误检率高 Precision < 0.7 负样本不足、阈值低 增加负样本、提高阈值
训练慢 训练时间长 CPU训练、批次小 使用GPU、增大批次
不收敛 Loss不下降 学习率不当、数据问题 调整学习率、检查数据

6.3 模型优化

优化策略

1. 数据优化

增加数据量

  • 收集更多高质量数据
  • 使用数据增强(旋转、翻转、亮度等)
  • 从公开数据集补充数据

提升数据质量

  • 重新检查标注,修正错误
  • 统一标注标准
  • 平衡各类别数据

数据增强脚本

# 使用YOLOv8内置的数据增强
# 在训练时自动应用,无需手动处理
# 可通过参数调整:
model.train(
    hsv_h=0.015,    # 色调增强
    hsv_s=0.7,     # 饱和度增强
    hsv_v=0.4,     # 明度增强
    degrees=10,    # 旋转角度
    translate=0.1, # 平移
    scale=0.5,     # 缩放
    mosaic=1.0,    # Mosaic增强
    mixup=0.15,    # MixUp增强
)

2. 超参数优化

学习率优化

# 尝试不同的学习率
learning_rates = [0.001, 0.005, 0.01, 0.02]

for lr in learning_rates:
    model = YOLO('yolov8n.pt')
    results = model.train(
        data='dataset.yaml',
        epochs=50,
        lr0=lr,
        name=f'lr_{lr}',
    )
    print(f"LR={lr}, mAP50={results.results_dict['metrics/mAP50(B)']:.4f}")

批次大小优化

# 根据GPU内存调整批次大小
# 更大的批次通常更稳定,但需要更多内存
batch_sizes = [8, 16, 32]

for batch in batch_sizes:
    model = YOLO('yolov8n.pt')
    results = model.train(
        data='dataset.yaml',
        epochs=50,
        batch=batch,
        name=f'batch_{batch}',
    )

3. 模型选择优化

模型大小对比

# 测试不同大小的模型
models = ['yolov8n.pt', 'yolov8s.pt', 'yolov8m.pt']

for model_name in models:
    model = YOLO(model_name)
    results = model.train(
        data='dataset.yaml',
        epochs=100,
        name=model_name.replace('.pt', ''),
    )
    print(f"{model_name}: mAP50={results.results_dict['metrics/mAP50(B)']:.4f}")

4. 后处理优化

调整置信度阈值

# 默认阈值是0.25,可以根据需求调整
# 提高阈值:减少误检,但可能增加漏检
# 降低阈值:减少漏检,但可能增加误检

# 推理时调整
results = model('test_image.jpg', conf=0.3)  # 提高阈值
results = model('test_image.jpg', conf=0.15)  # 降低阈值

调整NMS阈值

# NMS(Non-Maximum Suppression)用于去除重复检测
# iou参数控制NMS的IoU阈值
# 提高iou:更严格的NMS,减少重复检测
# 降低iou:更宽松的NMS,可能保留更多检测框

results = model('test_image.jpg', iou=0.45)  # 默认0.7

5. 模型集成

多模型投票

from ultralytics import YOLO
import numpy as np

# 加载多个模型
models = [
    YOLO('runs/detect/model1/weights/best.pt'),
    YOLO('runs/detect/model2/weights/best.pt'),
    YOLO('runs/detect/model3/weights/best.pt'),
]

# 对同一图片进行预测
image = 'test_image.jpg'
predictions = [model(image, conf=0.25) for model in models]

# 投票或平均(简化示例)
# 实际应用中需要更复杂的集成策略

优化检查清单

数据优化

  • 数据量是否足够
  • 数据质量是否高
  • 类别是否平衡
  • 场景是否多样

训练优化

  • 学习率是否合适
  • 批次大小是否合理
  • 训练轮数是否足够
  • 数据增强是否启用

模型优化

  • 模型大小是否合适
  • 是否使用预训练权重
  • 是否尝试不同模型

后处理优化

  • 置信度阈值是否合适
  • NMS阈值是否合适
  • 是否考虑模型集成

性能评估

  • mAP是否达到目标
  • Precision和Recall是否平衡
  • 各类别性能是否均衡
  • 实际应用效果是否满意

🎁 使用 TjMakeBot 加速数据集制作

TjMakeBot 的优势

  1. AI 聊天式标注

    • 自然语言指令,快速标注
    • 支持批量处理
    • 准确率高
  2. 视频转帧功能

    • 从视频中提取帧
    • 自定义帧率
    • 批量处理
  3. 多格式支持

    • YOLO 格式导出
    • VOC、COCO 格式支持
    • 格式转换便捷
  4. 免费(基础功能)

    • 无使用限制
    • 无功能限制
    • 在线即用

立即免费使用 TjMakeBot 制作 YOLO 数据集 →

📚 相关阅读

💬 结语

制作高质量的 YOLO 数据集是模型成功的基础。通过选择合适的工具、遵循实用方法、持续优化,你就能创建出高质量的数据集,训练出优秀的模型。

记住:数据质量 > 模型架构。投资时间在数据上,回报是显著的。


法律声明:本文内容仅供参考,不构成任何法律、商业或技术建议。使用任何工具或方法时,请遵守相关法律法规,尊重知识产权,获得必要的授权。本文提及的所有公司名称、产品名称和商标均为其各自所有者的财产。

关于作者:TjMakeBot 团队专注于 AI 数据标注工具开发,帮助开发者快速创建高质量的 YOLO 数据集。

📚 推荐阅读

关键词:YOLO数据集、目标检测、YOLO标注、YOLOv8、YOLOv5、数据集制作、图像标注、TjMakeBot