← 返回文章详情

track_pc.py

相关文档预览

import cv2
import numpy as np
import time
from ultralytics import YOLO
"""
测试步骤:
创建虚拟环境: python -m venv venv
激活虚拟环境: .venv\Scripts\activate
安装依赖: pip install opencv-contrib-python ultralytics numpy
运行程序: python track.py
请用纸团和瓶子进行测试,屏幕上会显示追踪到的目标以及预测的400ms后它的位置。
控制台会给出四个麦轮的转动方向、速度和时间指令。
按q退出程序。
若要修改模型,请替换 'best.pt' 文件路径(第26行和288行)并修改 TARGET_CLASS_ID 变量(第42行)。
画面中的蓝色框表示检测到的动态物体,绿色框才表示进入追踪状态的目标,白色的点表示预测的位置。
"""
class HybridTracker:
    """
    一个完全封装的、带状态机的智能追踪器。
    - 使用稠密光流+YOLO进行目标确认。
    - 使用CSRT进行预测性追踪,并通过YOLO周期性再确认。
    - 通过边界检查、静止检测和视觉丢失容忍来判断追踪失败。
    - 内部处理所有控制逻辑,包括带智能记忆的惯性航行。
    - 对外只提供一个 process_frame 接口,始终返回当前最合理的电机指令。
    """
    def __init__(self, model_path='best.pt'):
        """
        初始化追踪器,加载模型和摄像头,设置所有核心参数。
        :param model_path: YOLO模型的路径。
        """
        self.cap = cv2.VideoCapture(0)
        if not self.cap.isOpened(): print("错误:找不到摄像头"); raise SystemExit
        self.cap.set(cv2.CAP_PROP_FRAME_WIDTH, 640)
        self.cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 480)
        self.frame_width, self.frame_height = 640, 480
        self.frame_center_x, self.frame_center_y = self.frame_width // 2, self.frame_height // 2
        try:
            self.yolo_model = YOLO(model_path)
            print(f"YOLO模型 '{model_path}' 加载成功。"); print("模型可识别类别:", self.yolo_model.names)
        except Exception as e:
            print(f"错误: 加载YOLO模型失败: {e}"); raise SystemExit
        self.TARGET_CLASS_ID = 1
        self.CONFIDENCE_THRESHOLD = 0.25
        self.tracker, self.tracking, self.bbox = None, False, None
        self.MIN_TARGET_AREA, self.MAX_TARGET_AREA = 500, 25000
        self.MOTION_DETECTION_THRESHOLD = 4.0
        self.STATIC_FRAMES_TOLERANCE, self.STATIC_PIXEL_TOLERANCE = 5, 2
        self.TRACKING_LOSS_TOLERANCE = 10
        self.last_bbox_center, self.static_frames_count, self.lost_frames_count = None, 0, 0
        self.prev_gray = None
        self.last_center, self.last_time = None, None
        self.pixel_velocity, self.velocity_smoothing_factor = (0, 0), 0.7
        self.prediction_time_ms = 400
        self.COASTING_DURATION = 0.5
        self.dead_zone_radius = 15
        self.is_coasting = False
        self.coasting_end_time = 0.0
        self.last_motion_control_map = {}
        
        # --- 【新增】用于YOLO再确认的变量 ---
        self.reconfirm_frame_count = 0
        self.RECONFIRM_INTERVAL = 10 # 每10帧进行一次YOLO再确认

    def cleanup(self):
        """
        释放摄像头资源。
        """
        self.cap.release()
        cv2.destroyAllWindows()

    def _update_velocity(self, current_center):
        """
        根据连续帧的中心点变化,计算并平滑目标的速度向量。
        :param current_center: 目标当前的中心点坐标。
        """
        current_time = time.time()
        vx, vy = 0, 0
        if self.last_center is not None and self.last_time is not None:
            dt = current_time - self.last_time
            if dt > 1e-6:
                vx = (current_center[0] - self.last_center[0]) / dt
                vy = (current_center[1] - self.last_center[1]) / dt
        s = self.velocity_smoothing_factor
        self.pixel_velocity = (self.pixel_velocity[0] * s + vx * (1 - s), self.pixel_velocity[1] * s + vy * (1 - s))
        self.last_center, self.last_time = current_center, current_time

    def _track_target(self, frame):
        """
        当处于追踪状态时,调用此函数。
        使用CSRT追踪器更新目标位置,并执行多种理智检查(视觉丢失、边界碰撞、静止、YOLO再确认)。
        :param frame: 当前的彩色视频帧。
        :return: (bool: 是否成功, tuple: 中心点, tuple: 边界框)
        """
        success, bbox = self.tracker.update(frame)

        if not success:
            self.lost_frames_count += 1
            if self.lost_frames_count > self.TRACKING_LOSS_TOLERANCE:
                self._reset_tracker()
                return False, (0,0), None
            else:
                return False, (0,0), self.bbox
        
        self.lost_frames_count = 0
        x, y, w, h = map(int, bbox)
        frame_height, frame_width, _ = frame.shape
        
        sane = True
        if w <= 0 or h <= 0 or x <= 2 or y <= 2 or (x + w) >= frame_width - 2 or (y + h) >= frame_height - 2: 
            sane = False
        current_center = (x + w // 2, y + h // 2)
        if self.last_bbox_center is not None:
            dist = np.sqrt((current_center[0] - self.last_bbox_center[0])**2 + (current_center[1] - self.last_bbox_center[1])**2)
            if dist < self.STATIC_PIXEL_TOLERANCE: self.static_frames_count += 1
            else: self.static_frames_count = 0
        self.last_bbox_center = current_center
        if self.static_frames_count > self.STATIC_FRAMES_TOLERANCE:
            print("\n检测到追踪框静止,判定为目标丢失...")
            sane = False
        
        # --- 周期性的YOLO再确认 ---
        self.reconfirm_frame_count += 1
        if sane and self.reconfirm_frame_count >= self.RECONFIRM_INTERVAL:
            self.reconfirm_frame_count = 0 # 重置计数器
            roi = frame[y:y+h, x:x+w]
            if roi.size > 0:
                results = self.yolo_model(roi, verbose=False)
                is_still_target = any(int(box.cls) == self.TARGET_CLASS_ID and box.conf > self.CONFIDENCE_THRESHOLD for result in results for box in result.boxes)
                if not is_still_target:
                    print(f"\nYOLO再确认失败!追踪器可能已粘滞。")
                    sane = False

        if sane:
            self.bbox = (x, y, w, h)
            center = (x + w // 2, y + h // 2)
            return True, center, self.bbox
        else:
            self._reset_tracker()
            return False, (0,0), None

    def _detect_new_target(self, frame, frame_gray):
        """
        当处于搜索状态时,调用此函数。
        使用稠密光流发现运动,并用YOLO确认运动物体是否为目标。
        :param frame: 当前的彩色视频帧。
        :param frame_gray: 当前的灰度视频帧。
        :return: (bool: 是否成功, tuple: 中心点, tuple: 边界框)
        """
        flow = cv2.calcOpticalFlowFarneback(self.prev_gray, frame_gray, None, 0.5, 3, 15, 3, 5, 1.2, 0)
        magnitude, _ = cv2.cartToPolar(flow[..., 0], flow[..., 1])
        motion_mask = (magnitude > self.MOTION_DETECTION_THRESHOLD).astype(np.uint8) * 255
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
        motion_mask = cv2.morphologyEx(motion_mask, cv2.MORPH_CLOSE, kernel)
        contours, _ = cv2.findContours(motion_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        detection_bbox_for_debug = None
        if not contours: return False, (0,0), None, motion_mask, detection_bbox_for_debug
        largest_contour = max(contours, key=cv2.contourArea)
        area = cv2.contourArea(largest_contour)
        x, y, w, h = cv2.boundingRect(largest_contour)
        detection_bbox_for_debug = (x, y, w, h)
        if self.MIN_TARGET_AREA < area < self.MAX_TARGET_AREA:
            roi = frame[y:y+h, x:x+w]
            if roi.size == 0: return False, (0,0), None, motion_mask, detection_bbox_for_debug
            results = self.yolo_model(roi, verbose=False)
            is_target_confirmed = any(int(box.cls) == self.TARGET_CLASS_ID and box.conf > self.CONFIDENCE_THRESHOLD for result in results for box in result.boxes)
            if is_target_confirmed:
                print("\n--- YOLO确认到目标 ---")
                self._initialize_tracker(frame, (x, y, w, h))
                center = (x + w // 2, y + h // 2)
                return True, center, self.bbox, motion_mask, detection_bbox_for_debug
        return False, (0,0), None, motion_mask, detection_bbox_for_debug

    def _initialize_tracker(self, frame, bbox):
        """
        初始化CSRT追踪器并重置所有相关状态变量。
        :param frame: 用于初始化的帧。
        :param bbox: 目标的初始边界框。
        """
        self.tracker = cv2.TrackerCSRT_create()
        self.tracker.init(frame, bbox)
        self.bbox, self.tracking = bbox, True
        center = (bbox[0] + bbox[2] // 2, bbox[1] + bbox[3] // 2)
        self.last_bbox_center = center
        self.last_center, self.last_time = center, time.time()
        self.pixel_velocity = (0, 0); self.static_frames_count = 0; self.lost_frames_count = 0
        self.reconfirm_frame_count = 0 

    def _reset_tracker(self):
        """
        重置追踪状态,清空所有追踪相关变量。
        """
        self.tracking, self.tracker, self.bbox, self.last_bbox_center = False, None, None, None
        self.static_frames_count = 0; self.pixel_velocity = (0, 0); self.lost_frames_count = 0

    def _calculate_mecanum_control(self, predicted_center):
        """
        计算四个麦轮的控制命令
        所有返回的数字都是标准的Python int类型。
        """
        dx = predicted_center[0] - self.frame_center_x
        dy = predicted_center[1] - self.frame_center_y

        if np.sqrt(dx**2 + dy**2) < self.dead_zone_radius:
            return {i: [0, 0, 40] for i in range(4)}

        Kp_translation, Kp_rotation = 0.4, 0.3
        vx, vy, v_rot = dy * Kp_translation, dx * Kp_translation, dx * Kp_rotation
        motor_speeds = [vx - vy - v_rot, vx + vy + v_rot, vx + vy - v_rot, vx - vy + v_rot]

        max_possible_speed = np.sqrt(self.frame_center_x**2 + self.frame_center_y**2) * Kp_translation
    
        control_map = {}
        COMMAND_DURATION_MS = 40
        for i in range(4):
            speed_percent_float = (motor_speeds[i] / max_possible_speed) * 100
            speed_percent_clipped = np.clip(speed_percent_float, -100, 100)

            action = 0 if speed_percent_clipped >= 0 else 1
            final_speed = int(abs(speed_percent_clipped)) 
        
            if final_speed < 5: 
                final_speed = 0
            
            control_map[i] = [int(action), int(final_speed), int(COMMAND_DURATION_MS)]
        
        return control_map

    def _is_motion_command(self, control_map):
        """
        辅助函数,判断一个控制指令是否包含实际运动。
        :param control_map: 电机指令字典。
        :return: bool
        """
        if not control_map: return False
        return any(control_map[motor_id][1] > 0 for motor_id in control_map)

    def process_frame(self, frame):
        """
        唯一的对外接口。处理一帧图像,并返回最终的电机控制指令。
        内部自动处理追踪、丢失和带智能记忆的惯性航行。
        :param frame: 从摄像头读取的原始视频帧。
        :return: 一个电机控制指令字典,如果无操作则为空字典。
        """
        was_tracking_before_this_frame = self.tracking
        if self.prev_gray is None:
            self.prev_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            return {}, False, None, None, None, None, None, False
        frame_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        detected, center, bbox = False, None, None
        motion_mask, detection_bbox, predicted_center = None, None, None
        if self.tracking:
            detected, center, bbox = self._track_target(frame)
        else:
            detected, center, bbox, motion_mask, detection_bbox = self._detect_new_target(frame, frame_gray)
        final_control_map = {}
        if detected:
            if self.is_coasting: print("\n航行中重新捕获目标!恢复精确追踪。")
            self.is_coasting = False
            self._update_velocity(center)
            vx_pix, vy_pix = self.pixel_velocity
            pred_time_sec = self.prediction_time_ms / 1000.0
            predicted_center = (int(center[0] + vx_pix * pred_time_sec), int(center[1] + vy_pix * pred_time_sec))
            current_control_map = self._calculate_mecanum_control(predicted_center)
            final_control_map = current_control_map
            if self._is_motion_command(current_control_map):
                self.last_motion_control_map = current_control_map
        else:
            if was_tracking_before_this_frame and not self.is_coasting:
                self.is_coasting = True
                self.coasting_end_time = time.time() + self.COASTING_DURATION
                print(f"\n视觉丢失!进入惯性航行,持续 {self.COASTING_DURATION} 秒...")
            if self.is_coasting:
                if time.time() < self.coasting_end_time:
                    final_control_map = self.last_motion_control_map
                else:
                    print("\n航行结束,目标未找回。停止电机。")
                    self.is_coasting = False
                    self.last_motion_control_map = {}
            else:
                final_control_map = {}
        self.prev_gray = frame_gray.copy()
        return final_control_map, detected, bbox, center, predicted_center, motion_mask, detection_bbox, self.is_coasting

def main():
    """
    主函数,初始化追踪器并处理视频流。
    """
    tracker = HybridTracker(model_path='best.pt')
    try:
        while True:
            ret, frame = tracker.cap.read()
            if not ret: print("无法读取视频帧,可能摄像头已断开。"); break
            final_control_map, detected, bbox, center, predicted_center, motion_mask, detection_bbox, is_coasting = tracker.process_frame(frame)
            if detected: print(f"追踪中... 指令: {final_control_map}", end='\r')
            elif is_coasting: print(f"惯性航行中... 指令: {final_control_map}", end='\r')
            else: print("正在搜索目标...                        ", end='\r')
            if detected and bbox:
                x, y, w, h = bbox
                cv2.rectangle(frame, (x, y), (x + w, y + h), (0, 255, 0), 2)
                cv2.circle(frame, center, 5, (0, 0, 255), -1)
                if predicted_center:
                    cv2.circle(frame, predicted_center, 7, (255, 255, 255), -1)
                    cv2.line(frame, center, predicted_center, (255, 255, 0), 2)
            elif is_coasting:
                cv2.putText(frame, "COASTING...", (20, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2)
            elif detection_bbox:
                x, y, w, h = detection_bbox
                cv2.rectangle(frame, (x, y), (x + w, y + h), (255, 0, 0), 2)
                cv2.putText(frame, "Motion Detected", (x, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 0, 0), 2)
            cv2.line(frame, (tracker.frame_center_x - 20, tracker.frame_center_y), (tracker.frame_center_x + 20, tracker.frame_center_y), (255, 0, 255), 2)
            cv2.line(frame, (tracker.frame_center_x, tracker.frame_center_y - 20), (tracker.frame_center_x, tracker.frame_center_y + 20), (255, 0, 255), 2)
            cv2.imshow('Hybrid Tracker - Debug Mode', frame)
            if motion_mask is not None:
                cv2.imshow('Motion Mask (Debug)', motion_mask)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                print("\n程序被用户中断。"); break
    except KeyboardInterrupt:
        print("\n程序被用户中断。")
    finally:
        tracker.cleanup()

if __name__ == '__main__':
    main()
返回顶部