Created
August 20, 2020 13:51
-
-
Save Namburger/ae6d01f3c67528e6ad7342d7a026fd9f to your computer and use it in GitHub Desktop.
classify_camera.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Adapted from:
# https://github.com/googlecreativelab/teachablemachine-community/blob/master/snippets/markdown/image/edgetpu/python.md
from edgetpu.classification.engine import ClassificationEngine
from PIL import Image
import cv2
import re
import os
import sys

# Fail fast with a usage message instead of an IndexError when the two
# required command-line arguments are missing.
if len(sys.argv) < 3:
    sys.exit('Usage: {} <model_edgetpu.tflite> <labels.txt>'.format(
        os.path.basename(sys.argv[0])))

# Path to the TFLite model compiled for the Edge TPU.
modelPath = sys.argv[1]
# Path to the labels.txt that was downloaded with the model.
labelPath = sys.argv[2]
# This function parses labels.txt and puts it in a python dictionary
def loadLabels(labelPath):
    """Parse a labels file into a dict mapping class index -> label text.

    Each useful line is expected to look like "<index> <label>".  Lines that
    do not match (e.g. a blank trailing line) are skipped; the original code
    crashed with AttributeError because re.match() returns None for them.

    Args:
        labelPath: path to the labels.txt downloaded with the model.

    Returns:
        dict of {int class_index: str label}.
    """
    pattern = re.compile(r'\s*(\d+)(.+)')
    labels = {}
    with open(labelPath, 'r', encoding='utf-8') as labelFile:
        for line in labelFile:
            match = pattern.match(line)
            if match:
                num, text = match.groups()
                labels[int(num)] = text.strip()
    return labels
# This function takes in a PIL Image from any source or path you choose
def classifyImage(image, engine):
    """Classify a PIL image with the Edge TPU engine and return the result.

    Args:
        image: a PIL Image from any source.
        engine: a loaded ClassificationEngine.

    Returns:
        The engine's top-1 classification result list.
    """
    # Resize to the 224x224 square used when training the TM2 model.
    # PIL's resize() returns a NEW image; the original code discarded the
    # return value, so the model received the frame at its raw size.
    image = image.resize((224, 224))
    # Classify and output the inference result.
    return engine.classify_with_image(image, top_k=1)
def main():
    """Capture webcam frames and classify each one on the Coral Edge TPU.

    Runs until the capture fails or the user presses 'q' in the preview
    window; prints the top label and score for every classified frame.
    """
    # Load the model onto the Coral Edge TPU and parse the label map.
    engine = ClassificationEngine(modelPath)
    labels = loadLabels(labelPath)
    cap = cv2.VideoCapture(0)
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        # OpenCV delivers frames in BGR channel order; PIL (and the model,
        # which was trained on RGB images) expect RGB, so convert first.
        rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        pil_im = Image.fromarray(rgb)
        # Mirror the frame to match training.  transpose() returns a NEW
        # image; the original code discarded the result, so the flip (and a
        # redundant resize — classifyImage() resizes anyway) never applied.
        pil_im = pil_im.transpose(Image.FLIP_LEFT_RIGHT)
        # Classify, then display the raw camera frame.
        results = classifyImage(pil_im, engine)
        cv2.imshow('frame', frame)
        if results:
            print('Classification:', labels[results[0][0]],
                  'score:', str(results[0][1]))
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
    cap.release()
    cv2.destroyAllWindows()


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment