"""Invasive Crab Detection.

Command-line front end for the YOLO-based crab detector.

Examples:

    python3 yolo_detect.py --model-path Standard_Model.pt --source-type video --source rtsp://192.168.137.200:8889/cam --min-thresh 0.65 --resolution 1280x720

    python3 yolo_detect.py --model-path Standard_Model.pt --source-type usb --source usb0 --min-thresh 0.65 --resolution 1280x720
"""
import os

# Must be set before OpenCV is imported anywhere in the process: forces the
# FFmpeg backend to use RTSP over TCP with no added buffering delay.
os.environ["OPENCV_FFMPEG_CAPTURE_OPTIONS"] = "rtsp_transport;tcp|max_delay;0"

import sys
import argparse
import glob
import time


def parse_args(argv=None):
    """Parse the detector's command-line options.

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, which is what the script
            entry point relies on; passing a list allows programmatic use.

    Returns:
        argparse.Namespace with ``model_path`` (str), ``source_type``
        (``"video"`` or ``"usb"``), ``source`` (str), ``min_thresh``
        (float) and ``resolution`` (``"WIDTHxHEIGHT"`` string).

    Raises:
        SystemExit: raised by argparse (status 2) on unknown options, a
            non-float ``--min-thresh`` or an unsupported ``--source-type``.
    """
    parser = argparse.ArgumentParser(
        description="Scan for invasive crabs with BlueStar"
    )

    parser.add_argument(
        "--model-path",
        default="Standard_Model.pt",
        help="Path to the Yolo model used",
    )

    # Restricting the choices turns a misspelled source type into an
    # immediate usage error instead of a failure deep inside the capture loop.
    parser.add_argument(
        "--source-type",
        default="video",
        choices=("video", "usb"),
        help="video or usb",
    )

    parser.add_argument(
        "--source",
        default="rtsp://192.168.137.200:8889/cam",
        help="Video source (ie usb0, rtsp://192.168.137.200:8889/cam)",
    )

    parser.add_argument(
        "--min-thresh",
        default=0.65,
        help="Minimum threshhold for marking as detected",
        type=float,
    )

    # NOTE(review): the default is 1260x720 while the usage examples pass
    # 1280x720 - looks like a typo, confirm before changing the default.
    parser.add_argument(
        "--resolution",
        default="1260x720",
        help="Source resolution",
    )

    return parser.parse_args(argv)
# Bounding box colors (Tableau 10 color scheme), picked by class index.
BBOX_COLORS = [(164, 120, 87), (68, 148, 228), (93, 97, 209), (178, 182, 133), (88, 159, 106),
               (96, 202, 231), (159, 124, 168), (169, 162, 241), (98, 118, 150), (172, 176, 184)]

# Number of most recent frames averaged for the on-screen FPS readout.
FPS_AVG_LEN = 200


def parse_resolution(resolution):
    """Convert a ``WIDTHxHEIGHT`` string into a ``(width, height)`` tuple.

    Args:
        resolution: Text such as ``"1280x720"``; an empty value or ``None``
            means "do not resize".

    Returns:
        ``(width, height)`` as positive ints, or ``None`` when *resolution*
        is empty.

    Raises:
        ValueError: if the text is not two positive integers joined by ``x``.
    """
    if not resolution:
        return None
    parts = resolution.lower().split('x')
    try:
        width, height = int(parts[0]), int(parts[1])
    except (IndexError, ValueError):
        raise ValueError(
            f'Invalid resolution {resolution!r}; expected WIDTHxHEIGHT, e.g. 1280x720.'
        ) from None
    if width <= 0 or height <= 0:
        raise ValueError(f'Invalid resolution {resolution!r}; both sides must be positive.')
    return width, height


def capture_target(source_type, source):
    """Return the value to hand to ``cv2.VideoCapture`` for this source.

    Args:
        source_type: ``"video"`` or ``"usb"``.
        source: For ``"video"`` a file path or stream URL, passed through
            unchanged. For ``"usb"`` a name like ``"usb0"``; the first three
            characters are dropped and the rest is used as the device index.

    Returns:
        The source string for video sources, or the int device index for
        USB cameras.

    Raises:
        ValueError: on an unknown *source_type* or a USB name with no
            numeric index.
    """
    if source_type == 'video':
        # The capture must open the source itself, not the type keyword.
        return source
    if source_type == 'usb':
        try:
            return int(source[3:])
        except ValueError:
            raise ValueError(
                f'Invalid USB source {source!r}; expected something like usb0.'
            ) from None
    raise ValueError(f'Unsupported source type {source_type!r}; use video or usb.')


def _draw_detections(frame, detections, labels, min_thresh):
    """Draw every detection above *min_thresh* onto *frame* in place.

    Args:
        frame: Image array that is annotated in place.
        detections: The ``boxes`` object of one inference result; each item
            is read through its ``xyxy``, ``cls`` and ``conf`` attributes.
        labels: Mapping from class index to class name (``model.names``).
        min_thresh: Detections with confidence at or below this are skipped.

    Returns:
        Number of detections that were drawn.
    """
    import cv2

    font = cv2.FONT_HERSHEY_SIMPLEX
    object_count = 0

    for i in range(len(detections)):
        box = detections[i]

        conf = box.conf.item()
        if conf <= min_thresh:
            continue

        # Results arrive as tensors; move to CPU, flatten to four values and
        # convert to plain Python ints for the drawing calls.
        xmin, ymin, xmax, ymax = box.xyxy.cpu().numpy().squeeze().astype(int).tolist()

        classidx = int(box.cls.item())
        classname = labels[classidx]
        color = BBOX_COLORS[classidx % len(BBOX_COLORS)]

        cv2.rectangle(frame, (xmin, ymin), (xmax, ymax), color, 2)

        label = f'{classname}: {int(conf*100)}%'
        label_size, base_line = cv2.getTextSize(label, font, 0.5, 1)
        # Keep the label inside the image when the box touches the top edge.
        label_ymin = max(ymin, label_size[1] + 10)
        # Filled background in the class color so the black text stays readable.
        cv2.rectangle(frame, (xmin, label_ymin - label_size[1] - 10),
                      (xmin + label_size[0], label_ymin + base_line - 10), color, cv2.FILLED)
        cv2.putText(frame, label, (xmin, label_ymin - 7), font, 0.5, (0, 0, 0), 1)

        object_count += 1

    return object_count


def main():
    """Run detection on the configured source until it ends or 'q' is pressed.

    Keys while the preview window has focus: ``q`` quits, ``s`` pauses until
    another key is pressed, ``p`` saves the annotated frame to
    ``capture.png``.

    Raises:
        SystemExit: with a non-zero status when the model file is missing,
            the arguments are malformed, or the source cannot be opened.
    """
    # Third-party imports are function-scoped: cv2 has to be imported after
    # OPENCV_FFMPEG_CAPTURE_OPTIONS is set at module import time, and the
    # module stays importable without the heavy dependencies installed.
    from collections import deque

    import cv2
    import numpy as np
    from ultralytics import YOLO

    args = parse_args()

    # Fail with a non-zero status so shells and supervisors see the error.
    if not os.path.exists(args.model_path):
        sys.exit('ERROR: Model path is invalid or model was not found. '
                 'Make sure the model filename was entered correctly.')

    try:
        frame_size = parse_resolution(args.resolution)
        cap_arg = capture_target(args.source_type, args.source)
    except ValueError as err:
        sys.exit(f'ERROR: {err}')

    # Load the model into memory and get the label map.
    model = YOLO(args.model_path, task='detect')
    labels = model.names

    cap = cv2.VideoCapture(cap_arg)

    avg_frame_rate = 0
    # A bounded deque drops the oldest sample automatically.
    frame_rate_buffer = deque(maxlen=FPS_AVG_LEN)

    try:
        if not cap.isOpened():
            sys.exit(f'ERROR: Unable to open {args.source_type} source {args.source!r}.')

        # Ask the source for the requested size; frames are resized below
        # anyway because not every source honours the request.
        if frame_size is not None:
            cap.set(cv2.CAP_PROP_FRAME_WIDTH, frame_size[0])
            cap.set(cv2.CAP_PROP_FRAME_HEIGHT, frame_size[1])

        while True:
            t_start = time.perf_counter()

            ret, frame = cap.read()
            if not ret or frame is None:
                if args.source_type == 'video':
                    print('Reached end of the video file. Exiting program.')
                else:
                    print('Unable to read frames from the camera. This indicates the camera '
                          'is disconnected or not working. Exiting program.')
                break

            if frame_size is not None:
                frame = cv2.resize(frame, frame_size)

            results = model(frame, verbose=False)
            object_count = _draw_detections(frame, results[0].boxes, labels, args.min_thresh)

            cv2.putText(frame, f'FPS: {avg_frame_rate:0.2f}', (10, 20),
                        cv2.FONT_HERSHEY_SIMPLEX, .7, (0, 255, 255), 2)
            cv2.putText(frame, f'Number of objects: {object_count}', (10, 40),
                        cv2.FONT_HERSHEY_SIMPLEX, .7, (0, 255, 255), 2)
            cv2.imshow('YOLO detection results', frame)

            # Wait 5ms for a key; mask to the low byte so key codes compare
            # equal to ord() values on every platform.
            key = cv2.waitKey(5) & 0xFF
            if key in (ord('q'), ord('Q')):
                break
            elif key in (ord('s'), ord('S')):
                cv2.waitKey()  # Block until any key: pauses inference.
            elif key in (ord('p'), ord('P')):
                cv2.imwrite('capture.png', frame)

            elapsed = time.perf_counter() - t_start
            if elapsed > 0:
                frame_rate_buffer.append(1 / elapsed)
                avg_frame_rate = np.mean(frame_rate_buffer)
    finally:
        # Always free the capture device and windows, even on errors.
        cap.release()
        cv2.destroyAllWindows()

    print(f'Average pipeline FPS: {avg_frame_rate:.2f}')


if __name__ == "__main__":
    main()