/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
*/

package com.evetensorflowpose.ml

import android.content.Context
import android.graphics.Bitmap
import android.graphics.PointF
import android.os.SystemClock
import android.util.Log
import org.tensorflow.lite.DataType
import org.tensorflow.lite.Interpreter
import com.evetensorflowpose.data.BodyPart
import com.evetensorflowpose.data.Device
import com.evetensorflowpose.data.KeyPoint
import com.evetensorflowpose.data.Person
import org.tensorflow.lite.gpu.GpuDelegate
import org.tensorflow.lite.support.common.FileUtil
import org.tensorflow.lite.support.common.ops.NormalizeOp
import org.tensorflow.lite.support.image.ImageProcessor
import org.tensorflow.lite.support.image.TensorImage
import org.tensorflow.lite.support.image.ops.ResizeOp
import org.tensorflow.lite.support.image.ops.ResizeWithCropOrPadOp
import kotlin.math.exp

/**
 * Single-person pose estimator backed by the PoseNet TFLite model.
 *
 * Each call to [estimatePoses] pads the bitmap to a square, resizes it to the model input size,
 * normalizes pixels with `(x - MEAN) / STD`, runs the interpreter, and decodes the heatmap and
 * offset outputs into one [Person] whose keypoints are in the bitmap's pixel coordinates.
 *
 * NOTE(review): per-frame crop state is kept in mutable fields between preprocessing and
 * post-processing, so one instance should not be shared by concurrent callers.
 *
 * @param interpreter interpreter loaded with the PoseNet model; closed by [close].
 * @param gpuDelegate GPU delegate attached to [interpreter], or null; closed by [close].
 */
class PoseNet(private val interpreter: Interpreter, private var gpuDelegate: GpuDelegate?) :
    PoseDetector {

    companion object {
        private const val CPU_NUM_THREADS = 4
        private const val MEAN = 127.5f
        private const val STD = 127.5f
        private const val TAG = "Posenet"
        private const val MODEL_FILENAME = "posenet.tflite"

        /**
         * Builds a [PoseNet] for [MODEL_FILENAME] (loaded with `FileUtil.loadMappedFile`) that
         * runs on [device]: a GPU delegate for [Device.GPU], NNAPI for [Device.NNAPI], and plain
         * CPU threads otherwise.
         *
         * If loading the model or creating the interpreter throws, any GPU delegate created
         * here is closed before the exception propagates so it is not leaked.
         */
        fun create(context: Context, device: Device): PoseNet {
            val options = Interpreter.Options()
            options.setNumThreads(CPU_NUM_THREADS)

            var gpuDelegate: GpuDelegate? = null
            when (device) {
                Device.CPU -> { }
                Device.GPU -> {
                    val delegate = GpuDelegate()
                    options.addDelegate(delegate)
                    gpuDelegate = delegate
                }
                Device.NNAPI -> options.setUseNNAPI(true)
            }

            val interpreter = try {
                Interpreter(FileUtil.loadMappedFile(context, MODEL_FILENAME), options)
            } catch (t: Throwable) {
                // The delegate would otherwise be unreachable and never freed.
                gpuDelegate?.close()
                throw t
            }
            return PoseNet(interpreter, gpuDelegate)
        }
    }

    private var lastInferenceTimeNanos: Long = -1

    // TFLite image input tensors are laid out as [batch, height, width, channels], so dimension 1
    // is the height and dimension 2 is the width.
    private val inputHeight = interpreter.getInputTensor(0).shape()[1]
    private val inputWidth = interpreter.getInputTensor(0).shape()[2]

    // Written by processInputImage for every frame and read back by post-processing:
    // the padding (in bitmap pixels) added along each axis, and the side of the padded square.
    private var cropHeight = 0f
    private var cropWidth = 0f
    private var cropSize = 0

    /**
     * Runs PoseNet on [bitmap] and returns a single-element list containing the detected [Person].
     * Timings for preprocessing, inference and post-processing are logged at INFO level.
     */
    @Suppress("UNCHECKED_CAST")
    override fun estimatePoses(bitmap: Bitmap): List<Person> {
        val estimationStartTimeNanos = SystemClock.elapsedRealtimeNanos()
        val inputArray = arrayOf(processInputImage(bitmap).tensorBuffer.buffer)
        Log.i(
            TAG,
            String.format(
                "Scaling to [-1,1] took %.2f ms",
                (SystemClock.elapsedRealtimeNanos() - estimationStartTimeNanos) / 1_000_000f
            )
        )

        val outputMap = initOutputMap(interpreter)

        val inferenceStartTimeNanos = SystemClock.elapsedRealtimeNanos()
        interpreter.runForMultipleInputsOutputs(inputArray, outputMap)
        lastInferenceTimeNanos = SystemClock.elapsedRealtimeNanos() - inferenceStartTimeNanos
        Log.i(
            TAG,
            String.format("Interpreter took %.2f ms", 1.0f * lastInferenceTimeNanos / 1_000_000)
        )

        // Buffers were allocated by initOutputMap with exactly these types.
        val heatmaps = outputMap[0] as Array<Array<Array<FloatArray>>>
        val offsets = outputMap[1] as Array<Array<Array<FloatArray>>>

        val postProcessingStartTimeNanos = SystemClock.elapsedRealtimeNanos()
        val person = postProcessModelOutputs(heatmaps, offsets)
        Log.i(
            TAG,
            String.format(
                "Postprocessing took %.2f ms",
                (SystemClock.elapsedRealtimeNanos() - postProcessingStartTimeNanos) / 1_000_000f
            )
        )

        return listOf(person)
    }

    /**
     * Converts PoseNet's [heatmaps] and [offsets] outputs into a [Person].
     *
     * For each keypoint channel the highest-scoring heatmap cell is located. The cell position is
     * scaled to model-input pixels and refined with the matching offset (y offsets are in channels
     * `[0, numKeypoints)`, x offsets in `[numKeypoints, 2 * numKeypoints)`). The result is then
     * mapped back into the source bitmap by undoing the resize and the square padding. A keypoint's
     * score is the sigmoid of its raw heatmap value. The person's score is the keypoint scores
     * summed and divided by `numKeypoints`.
     *
     * NOTE(review): keypoints are labelled by position in `enumValues<BodyPart>()`; this assumes
     * the enum has exactly `numKeypoints` entries in the model's channel order.
     */
    private fun postProcessModelOutputs(
        heatmaps: Array<Array<Array<FloatArray>>>,
        offsets: Array<Array<Array<FloatArray>>>
    ): Person {
        val heatmap = heatmaps[0]
        val offset = offsets[0]
        val height = heatmap.size
        val width = heatmap[0].size
        val numKeypoints = heatmap[0][0].size

        // Divisors that map a heatmap cell index onto model-input pixels. Guarded so a single-cell
        // axis does not divide by zero (which would turn every coordinate into NaN).
        val rowSpan = (height - 1).coerceAtLeast(1).toFloat()
        val colSpan = (width - 1).coerceAtLeast(1).toFloat()

        // Model-input pixels -> padded square (ratio) -> original bitmap (minus half the padding).
        val ratioHeight = cropSize.toFloat() / inputHeight
        val ratioWidth = cropSize.toFloat() / inputWidth
        val paddingHeight = cropHeight / 2
        val paddingWidth = cropWidth / 2

        val keyPoints = mutableListOf<KeyPoint>()
        var totalScore = 0.0f
        enumValues<BodyPart>().forEachIndexed { idx, bodyPart ->
            val (row, col) = findMaxCell(heatmap, idx)

            val inputImageCoordinateY = row / rowSpan * inputHeight + offset[row][col][idx]
            val inputImageCoordinateX =
                col / colSpan * inputWidth + offset[row][col][idx + numKeypoints]

            // Coordinates are truncated to whole bitmap pixels.
            val x = (inputImageCoordinateX * ratioWidth - paddingWidth).toInt().toFloat()
            val y = (inputImageCoordinateY * ratioHeight - paddingHeight).toInt().toFloat()
            val score = sigmoid(heatmap[row][col][idx])

            keyPoints.add(KeyPoint(bodyPart, PointF(x, y), score))
            totalScore += score
        }
        return Person(keyPoints = keyPoints.toList(), score = totalScore / numKeypoints)
    }

    /**
     * Returns the (row, col) of the highest value in channel [keypoint] of [heatmap]. The scan is
     * row-major and uses a strict comparison, so the first maximum wins ties.
     */
    private fun findMaxCell(heatmap: Array<Array<FloatArray>>, keypoint: Int): Pair<Int, Int> {
        var maxVal = heatmap[0][0][keypoint]
        var maxRow = 0
        var maxCol = 0
        for (row in heatmap.indices) {
            for (col in heatmap[row].indices) {
                val value = heatmap[row][col][keypoint]
                if (value > maxVal) {
                    maxVal = value
                    maxRow = row
                    maxCol = col
                }
            }
        }
        return Pair(maxRow, maxCol)
    }

    override fun lastInferenceTimeNanos(): Long = lastInferenceTimeNanos

    /**
     * Releases the interpreter and then the GPU delegate, if any. The interpreter is closed first
     * because it may still reference the delegate. The delegate reference is cleared so a second
     * call does not close it again.
     */
    override fun close() {
        interpreter.close()
        gpuDelegate?.close()
        gpuDelegate = null
    }

    /**
     * Scales and pads [bitmap] into the model's input [TensorImage].
     *
     * The bitmap is padded to a `cropSize x cropSize` square. The padding added along each axis is
     * recorded in [cropWidth] / [cropHeight] for post-processing. The square is then resized
     * bilinearly to the model input size and normalized with `(x - MEAN) / STD`.
     */
    private fun processInputImage(bitmap: Bitmap): TensorImage {
        // Reset the padding from the previous frame.
        cropWidth = 0f
        cropHeight = 0f
        cropSize = if (bitmap.width > bitmap.height) {
            cropHeight = (bitmap.width - bitmap.height).toFloat()
            bitmap.width
        } else {
            cropWidth = (bitmap.height - bitmap.width).toFloat()
            bitmap.height
        }

        // ResizeOp takes (targetHeight, targetWidth).
        val imageProcessor = ImageProcessor.Builder().apply {
            add(ResizeWithCropOrPadOp(cropSize, cropSize))
            add(ResizeOp(inputHeight, inputWidth, ResizeOp.ResizeMethod.BILINEAR))
            add(NormalizeOp(MEAN, STD))
        }.build()
        val tensorImage = TensorImage(DataType.FLOAT32)
        tensorImage.load(bitmap)
        return imageProcessor.process(tensorImage)
    }

    /**
     * Allocates one 4-D float buffer per model output, sized from that output tensor's shape, for
     * the interpreter to fill.
     *
     * Output order: 0 = heatmaps (e.g. 1 * 9 * 9 * 17), 1 = offsets (e.g. 1 * 9 * 9 * 34),
     * 2 = forward displacements (e.g. 1 * 9 * 9 * 32), 3 = backward displacements
     * (e.g. 1 * 9 * 9 * 32).
     */
    private fun initOutputMap(interpreter: Interpreter): HashMap<Int, Any> {
        val outputMap = HashMap<Int, Any>()
        for (outputIndex in 0 until 4) {
            outputMap[outputIndex] = allocateOutput(interpreter.getOutputTensor(outputIndex).shape())
        }
        return outputMap
    }

    /** Allocates a zeroed float array with the four dimensions given by [shape]. */
    private fun allocateOutput(shape: IntArray): Array<Array<Array<FloatArray>>> =
        Array(shape[0]) { Array(shape[1]) { Array(shape[2]) { FloatArray(shape[3]) } } }

    /** Logistic sigmoid; returns a value within [0,1]. */
    private fun sigmoid(x: Float): Float {
        return (1.0f / (1.0f + exp(-x)))
    }
}