ボールがある方向を向く(パン&チルト)

前回の記事: touch-sp.hatenablog.com
前回は1軸であったが、今回はパン（横方向）とチルト（縦方向）の2軸にした。
f:id:touch-sp:20200611215803j:plain:w320
ボールを検出するために、学習済みのSSDモデルをベースに転移学習を行った（詳細はこちらを参照）。

Pythonコード

import mxnet as mx
import gluoncv

import serial, time
import cv2, queue, threading

class VideoCapture:
    """Frame grabber whose ``read`` always hands back the most recent frame.

    Wraps ``cv2.VideoCapture``: a background daemon thread reads frames as
    soon as they are available and keeps at most one unprocessed frame in
    ``self.q``, discarding older ones, so the consumer never works through a
    backlog of stale frames. Call :meth:`release` when done.
    """

    def __init__(self, name):
        """Open ``name`` with ``cv2.VideoCapture`` and start the reader thread."""
        self.cap = cv2.VideoCapture(name)
        self.q = queue.Queue()
        # Set by release() to ask the reader thread to stop.
        self._stopped = threading.Event()
        # Daemon so a still-running reader does not keep the process alive.
        self._thread = threading.Thread(target=self._reader, daemon=True)
        self._thread.start()

    # read frames as soon as they are available, keeping only most recent one
    def _reader(self):
        """Read frames until the stream ends or release() is called."""
        while not self._stopped.is_set():
            ret, frame = self.cap.read()
            if not ret:
                break
            if not self.q.empty():
                try:
                    self.q.get_nowait()  # discard previous (unprocessed) frame
                except queue.Empty:
                    pass
            self.q.put(frame)

    def read(self, timeout=None):
        """Return the newest frame, blocking until one is available.

        Args:
            timeout: Seconds to wait for a frame. ``None`` (the default)
                waits indefinitely.

        Raises:
            queue.Empty: If ``timeout`` is given and no frame arrives in
                time (for example because the stream has ended).
        """
        return self.q.get(timeout=timeout)

    def release(self, timeout=1.0):
        """Stop the reader thread and release the underlying capture.

        Args:
            timeout: Seconds to wait for the reader thread to finish. If it
                is still blocked inside ``cap.read()`` after this, the
                capture is released anyway.
        """
        self._stopped.set()
        thread = self._thread
        if thread is not None and thread.is_alive():
            thread.join(timeout)
        self.cap.release()

# Serial link to the Arduino that drives the pan/tilt servos.
ser = serial.Serial("COM4", 9600)
# Pause before talking to the board -- presumably to let the Arduino finish
# starting up after the port is opened (TODO confirm 1.5 s is enough).
time.sleep(1.5)

# Initial servo angles: yoko = pan (horizontal), tate = tilt (vertical).
servo_yoko = servo_tate = 90

# Detector: GluonCV SSD (MobileNet 1.0 backbone) with a single 'ball' class,
# loaded from locally trained parameters and moved to the GPU.
ctx = mx.gpu()
classes = ['ball']
net = gluoncv.model_zoo.get_model(
    'ssd_512_mobilenet1.0_custom',
    classes=classes,
    pretrained=False,
    root='./model',
)
net.load_parameters('ssd_512_mobilenet1.0_ball.params')
net.collect_params().reset_ctx(ctx)
# Hybridize (compile the graph) for faster inference.
net.hybridize()

# Threaded reader over the camera's RTSP stream; the short pause presumably
# lets the reader thread receive its first frame.
cap = VideoCapture('rtsp://192.72.1.1:554/liveRTSP/av4/track0')
time.sleep(1)

# Minimum detection score for moving the camera; also used as the drawing
# threshold so the on-screen box matches what drives the servos.
SCORE_THRESHOLD = 0.55
# Clamp ranges for the angles sent to the pan (yoko) and tilt (tate) servos.
YOKO_MIN, YOKO_MAX = 20, 160
TATE_MIN, TATE_MAX = 70, 110

try:
    while True:

        frame = cap.read()

        # Convert OpenCV's BGR frame to an RGB uint8 NDArray for GluonCV.
        frame = mx.nd.array(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)).astype('uint8')
        # Preprocess for the SSD (short side resized to 320 px); the second
        # return value is the resized image that is drawn on below.
        rgb_nd, frame = gluoncv.data.transforms.presets.ssd.transform_test(frame, short=320)

        # Run frame through network
        class_IDs, scores, bounding_boxes = net(rgb_nd.as_in_context(ctx))

        # Only the first detection is used; this assumes the SSD output is
        # sorted by descending score (TODO confirm for this model).
        if scores[0][0].asscalar() > SCORE_THRESHOLD:
            # Copy the box to the host once instead of forcing a device sync
            # for every comparison below.
            x_min, y_min, x_max, y_max = bounding_boxes[0][0].asnumpy()

            # NOTE(review): these pixel thresholds assume a frame centred near
            # x=320 / y=180; transform_test(short=320) makes the short side
            # 320 px, so confirm they match the stream's aspect ratio.
            # Pan: large step when the ball is far off-centre, small when near.
            if x_min > 360:
                servo_yoko += 15 if x_min > 480 else 3
            if x_max < 280:
                servo_yoko -= 15 if x_max < 160 else 3
            servo_yoko = min(max(servo_yoko, YOKO_MIN), YOKO_MAX)

            # Tilt: same scheme, with a smaller far step (5 instead of 15).
            if y_min > 180:
                servo_tate += 5 if y_min > 270 else 3
            if y_max < 180:
                servo_tate -= 5 if y_max < 90 else 3
            servo_tate = min(max(servo_tate, TATE_MIN), TATE_MAX)

            # Wait until the previous packet has left the output buffer;
            # pyserial's flush() blocks until all written data is sent,
            # instead of busy-polling ser.out_waiting.
            ser.flush()
            # One 3-byte packet: pan angle, tilt angle, 255 as end marker.
            ser.write(bytes([servo_yoko, servo_tate, 255]))

        # Display the result
        img = gluoncv.utils.viz.cv_plot_bbox(
            frame, bounding_boxes[0], scores[0], class_IDs[0],
            class_names=net.classes, thresh=SCORE_THRESHOLD,
        )
        gluoncv.utils.viz.cv_plot_image(img)

        # Quit when Esc is pressed.
        if cv2.waitKey(1) == 27:
            break

        time.sleep(0.2)
finally:
    # Release the serial port, stream and windows even if the loop raises
    # (including KeyboardInterrupt).
    ser.close()
    cap.release()
    cv2.destroyAllWindows()

Arduinoスケッチ

#include <Servo.h>

// Pan (yoko) and tilt (tate) servos driven by 3-byte packets from the PC:
// [pan angle][tilt angle][255]. 255 is used only as the end-of-packet marker.
Servo myServo_yoko;
Servo myServo_tate;
int yoko = 90;
int tate = 90;

// Angle bytes of the packet being received, and how many have arrived since
// the last end marker (3 means "too many -- discard until next marker").
byte packet[2];
int received = 0;

void setup() {
  myServo_yoko.attach(9);
  myServo_tate.attach(10);
  myServo_yoko.write(yoko);
  myServo_tate.write(tate);
  Serial.begin(9600);
}

void loop() {
  while (Serial.available() > 0) {
    int b = Serial.read();
    if (b == 255) {
      // End marker: apply the packet only if exactly two angle bytes preceded
      // it, so a lost or extra byte costs one packet instead of permanently
      // misaligning pan and tilt.
      if (received == 2) {
        yoko = packet[0];
        tate = packet[1];
        myServo_yoko.write(yoko);
        myServo_tate.write(tate);
        delay(100);
      }
      received = 0;
    } else if (received < 2) {
      packet[received] = (byte)b;
      received++;
    } else {
      // More than two bytes without a marker: packet is corrupt; ignore the
      // rest of it and resynchronise on the next 255.
      received = 3;
    }
  }
}