Skip to content

Instantly share code, notes, and snippets.

@Namburger
Created August 20, 2020 13:51
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Namburger/ae6d01f3c67528e6ad7342d7a026fd9f to your computer and use it in GitHub Desktop.
Save Namburger/ae6d01f3c67528e6ad7342d7a026fd9f to your computer and use it in GitHub Desktop.
classify_camera.py
# Adapted from https://github.com/googlecreativelab/teachablemachine-community/blob/master/snippets/markdown/image/edgetpu/python.md, with some bug fixes applied.
from edgetpu.classification.engine import ClassificationEngine
from PIL import Image
import cv2
import re
import os
import sys
# Path to the TFLite model converted for use with the Edge TPU.
# Read from the first command-line argument when this module is loaded;
# raises IndexError if the argument is missing.
modelPath = sys.argv[1]
# Path to the labels.txt that was downloaded with your model.
# Read from the second command-line argument when this module is loaded;
# raises IndexError if the argument is missing.
labelPath = sys.argv[2]
# Parse labels.txt into a {class index: label text} dictionary.
def loadLabels(labelPath):
    """Read a labels file into a dictionary keyed by class index.

    Each useful line is expected to start with an integer class index
    (optionally preceded by whitespace) followed by the label text,
    for example ``0 Cat``.

    Args:
        labelPath: Path to the labels file; it is read as UTF-8.

    Returns:
        dict mapping the ``int`` class index to the label text with
        surrounding whitespace stripped.
    """
    pattern = re.compile(r'\s*(\d+)(.+)')
    labels = {}
    with open(labelPath, 'r', encoding='utf-8') as labelFile:
        for line in labelFile:
            match = pattern.match(line)
            # Blank or malformed lines (e.g. a trailing empty line) do not
            # match the pattern; skip them instead of failing on None.
            if match is None:
                continue
            num, text = match.groups()
            labels[int(num)] = text.strip()
    return labels
# Classify a PIL Image obtained from any source or path you choose.
def classifyImage(image, engine):
    """Resize an image to 224x224 and classify it with the given engine.

    Args:
        image: Image object exposing ``resize(size)`` (a PIL Image in this
            script).
        engine: Object exposing ``classify_with_image(image, top_k=...)``
            (a ``ClassificationEngine`` in this script).

    Returns:
        The value returned by ``engine.classify_with_image`` when asked for
        the single best result (``top_k=1``).
    """
    # resize() returns a new image instead of modifying the image in place,
    # so the name must be rebound for the 224x224 copy to actually be used.
    image = image.resize((224, 224))
    # Classify and output the inference result.
    return engine.classify_with_image(image, top_k=1)
def main():
    """Classify frames from the default camera until the stream ends or 'q' is pressed.

    Loads the model at ``modelPath`` and the labels at ``labelPath``, then for
    every captured frame shows the frame in a window named 'frame' and prints
    the top classification with its score.
    """
    # Load your model onto your Coral Edge TPU.
    engine = ClassificationEngine(modelPath)
    labels = loadLabels(labelPath)
    cap = cv2.VideoCapture(0)
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            # OpenCV delivers frames in BGR channel order while PIL expects
            # RGB, so convert before building the PIL Image for the Edge TPU.
            pil_im = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            # Resize and flip the image so it is square and matches training.
            # Both calls return new images, so the results must be kept.
            pil_im = pil_im.resize((224, 224))
            pil_im = pil_im.transpose(Image.FLIP_LEFT_RIGHT)
            # Classify and display the image.
            results = classifyImage(pil_im, engine)
            cv2.imshow('frame', frame)
            if results:
                label_id, score = results[0][0], results[0][1]
                print('Classification:', labels[label_id], 'score:', str(score))
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        # Release the camera and close the window even if the loop raised.
        cap.release()
        cv2.destroyAllWindows()
# Start the camera classification loop only when run as a script.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment