Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions science_officer/invasive_craba/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Invasive Crab Detection
Examples:
```
python3 yolo_detect.py --model-path Standard_Model.pt --source-type video --source rtsp://192.168.137.200:8889/cam --min-thresh 0.65 --resolution 1280x720
```

```
python3 yolo_detect.py --model-path Standard_Model.pt --source-type usb --source usb0 --min-thresh 0.65 --resolution 1280x720
```
Binary file added science_officer/invasive_craba/Standard_Model.pt
Binary file not shown.
186 changes: 186 additions & 0 deletions science_officer/invasive_craba/yolo_detect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
import os
os.environ["OPENCV_FFMPEG_CAPTURE_OPTIONS"] = "rtsp_transport;tcp|max_delay;0"
import sys
import argparse
import glob
import time

import cv2
import numpy as np
from ultralytics import YOLO

def parse_args(argv=None):
    """Parse the command-line options for the crab detection script.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, which is what the script itself relies on.

    Returns:
        argparse.Namespace with the attributes ``model_path``,
        ``source_type``, ``source``, ``min_thresh`` (float) and
        ``resolution`` (a ``"WIDTHxHEIGHT"`` string).
    """
    parser = argparse.ArgumentParser(description="Scan for invasive crabs with BlueStar")

    parser.add_argument(
        "--model-path",
        default="Standard_Model.pt",
        help="Path to the YOLO model used",
    )

    # Restricting the choices makes argparse reject typos up front instead of
    # letting the script fail later with an undefined capture object.
    parser.add_argument(
        "--source-type",
        default="video",
        choices=("video", "usb"),
        help="video or usb",
    )

    parser.add_argument(
        "--source",
        default="rtsp://192.168.137.200:8889/cam",
        help="Video source (ie usb0, rtsp://192.168.137.200:8889/cam)",
    )

    parser.add_argument(
        "--min-thresh",
        default=0.65,
        type=float,
        help="Minimum threshold for marking as detected",
    )

    # 1280x720 matches the resolution used in the README examples; the
    # previous default of 1260x720 was a typo.
    parser.add_argument(
        "--resolution",
        default="1280x720",
        help="Source resolution as WIDTHxHEIGHT",
    )

    return parser.parse_args(argv)

args = parse_args()

# Check that the model file exists; exit with a failing status if it does not
# so that calling scripts can detect the error.
if not os.path.exists(args.model_path):
    print('ERROR: Model path is invalid or model was not found. Make sure the model filename was entered correctly.')
    sys.exit(1)

# Load the model into memory and get the label map (class index -> name)
model = YOLO(args.model_path, task='detect')
labels = model.names

# Parse the user-specified resolution ("WIDTHxHEIGHT"). When set, it is both
# requested from the capture device and used to resize each frame.
resize = False
if args.resolution:
    try:
        resW, resH = (int(value) for value in args.resolution.lower().split('x'))
    except ValueError:
        print(f'ERROR: Resolution "{args.resolution}" is invalid. Use the form WIDTHxHEIGHT, e.g. 1280x720.')
        sys.exit(1)
    resize = True

# Work out what to hand to cv2.VideoCapture
if args.source_type == 'video':
    # A file path or stream URL is passed through unchanged
    cap_arg = args.source
elif args.source_type == 'usb':
    # USB sources are written as "usbN"; OpenCV wants the integer index N
    try:
        cap_arg = int(args.source[3:])
    except ValueError:
        print(f'ERROR: USB source "{args.source}" is invalid. Use the form usbN, e.g. usb0.')
        sys.exit(1)
else:
    print(f'ERROR: Source type "{args.source_type}" is not supported. Use "video" or "usb".')
    sys.exit(1)

cap = cv2.VideoCapture(cap_arg)
if not cap.isOpened():
    print(f'ERROR: Unable to open source "{args.source}".')
    sys.exit(1)

# Request the camera or video resolution if specified by the user
if resize:
    cap.set(cv2.CAP_PROP_FRAME_WIDTH, resW)
    cap.set(cv2.CAP_PROP_FRAME_HEIGHT, resH)

# Set bounding box colors (using the Tableau 10 color scheme)
bbox_colors = [(164,120,87), (68,148,228), (93,97,209), (178,182,133), (88,159,106),
               (96,202,231), (159,124,168), (169,162,241), (98,118,150), (172,176,184)]

# Initialize control and status variables
avg_frame_rate = 0
frame_rate_buffer = []  # per-frame FPS samples used for the running average
fps_avg_len = 200       # maximum number of samples kept in frame_rate_buffer

# Begin inference loop. The try/finally guarantees that the capture device and
# the display windows are released even if inference or display raises, or the
# user interrupts the program with Ctrl+C.
try:
    while True:
        t_start = time.perf_counter()

        # Load the next frame; video and USB sources share the same capture
        # object and differ only in the message printed on failure.
        ret, frame = cap.read()
        if not ret or frame is None:
            if args.source_type == 'video':
                print('Reached end of the video file. Exiting program.')
            else:
                print('Unable to read frames from the camera. This indicates the camera is disconnected or not working. Exiting program.')
            break

        # Resize frame to desired display resolution
        if resize:
            frame = cv2.resize(frame, (resW, resH))

        # Run inference on frame
        results = model(frame, verbose=False)

        # Extract results
        detections = results[0].boxes

        # Ultralytics returns tensors; move every box to CPU memory in one go
        # and convert to plain Python/NumPy values instead of per detection.
        boxes_xyxy = detections.xyxy.cpu().numpy().astype(int)
        class_ids = detections.cls.cpu().numpy().astype(int).tolist()
        confidences = detections.conf.cpu().numpy().tolist()

        # Number of detections above the confidence threshold in this frame
        object_count = 0

        for (xmin, ymin, xmax, ymax), classidx, conf in zip(boxes_xyxy, class_ids, confidences):

            # Only draw and count detections above the confidence threshold
            if conf <= args.min_thresh:
                continue

            classname = labels[classidx]
            color = bbox_colors[classidx % len(bbox_colors)]
            cv2.rectangle(frame, (xmin, ymin), (xmax, ymax), color, 2)

            label = f'{classname}: {int(conf*100)}%'
            labelSize, baseLine = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)  # Get font size
            label_ymin = max(ymin, labelSize[1] + 10)  # Make sure not to draw label too close to top of window
            # Filled box in the class color to put the label text in
            cv2.rectangle(frame, (xmin, label_ymin-labelSize[1]-10), (xmin+labelSize[0], label_ymin+baseLine-10), color, cv2.FILLED)
            cv2.putText(frame, label, (xmin, label_ymin-7), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1)  # Draw label text

            object_count += 1

        # Draw the running average framerate (computed over previous frames)
        cv2.putText(frame, f'FPS: {avg_frame_rate:0.2f}', (10,20), cv2.FONT_HERSHEY_SIMPLEX, .7, (0,255,255), 2)

        # Display detection results
        cv2.putText(frame, f'Number of objects: {object_count}', (10,40), cv2.FONT_HERSHEY_SIMPLEX, .7, (0,255,255), 2)
        cv2.imshow('YOLO detection results', frame)

        # Wait 5ms for a key press before moving to the next frame
        key = cv2.waitKey(5)

        if key in (ord('q'), ord('Q')):  # Press 'q' to quit
            break
        elif key in (ord('s'), ord('S')):  # Press 's' to pause until any key is pressed
            cv2.waitKey()
        elif key in (ord('p'), ord('P')):  # Press 'p' to save a picture of results on this frame
            cv2.imwrite('capture.png', frame)

        # Calculate FPS for this frame
        t_stop = time.perf_counter()
        frame_rate_calc = float(1/(t_stop - t_start))

        # Keep at most fps_avg_len samples, dropping the oldest first
        frame_rate_buffer.append(frame_rate_calc)
        if len(frame_rate_buffer) > fps_avg_len:
            del frame_rate_buffer[0]

        # Calculate average FPS for past frames
        avg_frame_rate = np.mean(frame_rate_buffer)
finally:
    # Clean up
    print(f'Average pipeline FPS: {avg_frame_rate:.2f}')
    if args.source_type in ('video', 'usb'):
        cap.release()
    cv2.destroyAllWindows()