更新opencv到4.8.0,增加onnxruntime用于推理onnx模型

This commit is contained in:
TonyJiangWJ 2023-08-24 23:55:40 +08:00
parent e842b26d21
commit bb1589fbf8
15 changed files with 588 additions and 6 deletions

View File

@ -35,10 +35,19 @@ public class ApkUnpackUtil {
if (!e.isDirectory() && !TextUtils.isEmpty(name)) {
File file = new File(mWorkspacePath, name);
System.out.println(file);
if (file.getParentFile() != null && file.getParentFile().mkdirs()) {
File parentFile = file.getParentFile();
if (parentFile != null && (parentFile.exists() || parentFile.mkdirs())) {
try (FileOutputStream fos = new FileOutputStream(file)) {
System.out.println(file.getName() + " has been written");
StreamUtils.write(zis, fos);
}
} else {
System.out.println(file.getName() + " can not write");
System.out.println("is parent null? " + (parentFile != null));
if (file.getParentFile() != null) {
System.out.println("is parent exists? " + parentFile.exists());
System.out.println("can parent mkdirs? " + parentFile.mkdirs());
}
}
} else {
System.out.println("file or empty" + name);

View File

@ -76,6 +76,7 @@ public class ApkBuilder {
Boolean usePaddleOcr = false;
Boolean useMlKitOcr = false;
Boolean useTessTwo = false;
Boolean useOnnx = false;
Set<String> enabledPermission = new HashSet<>();
public static AppConfig fromProjectConfig(String projectDir, ProjectConfig projectConfig) {
@ -187,6 +188,14 @@ public class ApkBuilder {
this.useTessTwo = useTessTwo;
}
public Boolean getUseOnnx() {
return useOnnx;
}
public void setUseOnnx(Boolean useOnnx) {
this.useOnnx = useOnnx;
}
public Set<String> getEnabledPermission() {
return enabledPermission;
}
@ -237,6 +246,9 @@ public class ApkBuilder {
if (!mAppConfig.useTessTwo) {
removeSoList.addAll(Arrays.asList("libtess.so", "liblept.so", "libjpgt.so", "libpngt.so"));
}
if (!mAppConfig.useOnnx) {
removeSoList.addAll(Arrays.asList("libonnxruntime.so", "libonnxruntime4j_jni.so"));
}
if (!mAppConfig.useOpenCv) {
removeSoList.add("libopencv_java4.so");
}

View File

@ -108,6 +108,9 @@ public class BuildActivity extends BaseActivity implements ApkBuilder.ProgressCa
@ViewById(R.id.use_tess_two)
CheckBox mUseTessTwo;
@ViewById(R.id.use_onnx_runtime)
CheckBox mUseOnnx;
@ViewById(R.id.recycler_view)
RecyclerView recyclerView;
@ -202,6 +205,7 @@ public class BuildActivity extends BaseActivity implements ApkBuilder.ProgressCa
"https://i.autojs.org/autojs/plugin/%d.apk", ApkBuilderPluginHelper.getSuitablePluginVersion()));
}
@SuppressLint("StringFormatInvalid")
private void setupWithSourceFile(ScriptFile file) {
String dir = file.getParent();
if (dir.startsWith(getFilesDir().getPath())) {
@ -349,6 +353,7 @@ public class BuildActivity extends BaseActivity implements ApkBuilder.ProgressCa
appConfig.setUsePaddleOcr(mUsePaddleOcr.isChecked());
appConfig.setUseMlKitOcr(mUseMlKitOcr.isChecked());
appConfig.setUseTessTwo(mUseTessTwo.isChecked());
appConfig.setUseOnnx(mUseOnnx.isChecked());
Set<String> enabledPermission = new HashSet<>();
for (Option option : options) {
if (option.isSelected()) {
@ -387,6 +392,7 @@ public class BuildActivity extends BaseActivity implements ApkBuilder.ProgressCa
Log.e(LOG_TAG, "Build failed", error);
}
@SuppressLint("StringFormatInvalid")
private void onBuildSuccessful(File outApk) {
mProgressDialog.dismiss();
mProgressDialog = null;

View File

@ -195,7 +195,7 @@ public class ProjectConfigActivity extends BaseActivity {
@Click(R.id.icon)
void selectIcon() {
ShortcutIconSelectActivity_.intent(this)
.flags(Intent.FLAG_ACTIVITY_NEW_TASK)
.flags(Intent.FLAG_ACTIVITY_MULTIPLE_TASK)
.startForResult(REQUEST_CODE);
}
@ -249,6 +249,7 @@ public class ProjectConfigActivity extends BaseActivity {
@SuppressLint("CheckResult")
@Override
protected void onActivityResult(int requestCode, int resultCode, Intent data) {
super.onActivityResult(requestCode, resultCode, data);
if (resultCode != RESULT_OK) {
return;
}

View File

@ -280,6 +280,13 @@
android:checked="true"
android:text="@string/text_use_ml_kit_ocr" />
<CheckBox
android:id="@+id/use_onnx_runtime"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:checked="false"
android:text="@string/text_use_onnx_runtime" />
<CheckBox
android:id="@+id/use_tess_two"
android:layout_width="wrap_content"

View File

@ -423,6 +423,7 @@
<string name="text_use_opencv">使用OpenCV</string>
<string name="text_use_paddle_ocr">使用PaddleOCR</string>
<string name="text_use_ml_kit_ocr">使用ML-KitOCR</string>
<string name="text_use_onnx_runtime">使用OnnxRuntime</string>
<string name="text_use_tess_two">使用TessTwoOCR</string>
<plurals name="air_error_short_description" key="air_error_short_description">
<item quantity="one">描述至少要 %d 个字.</item>

View File

@ -1,2 +1,2 @@
configurations.maybeCreate("default")
artifacts.add("default", file('opencv-4.5.5.aar'))
artifacts.add("default", file('opencv-4.8.0.aar'))

View File

@ -50,11 +50,11 @@ def archives = [
// ],
]
/**
* opencv4.5.5 4.2.0 AutoJS中的版本不匹配会产生冲突
* opencv4.8.0 4.2.0 AutoJS中的版本不匹配会产生冲突
*/
def zipArchives = [
[
'src' : 'https://github.com/opencv/opencv/releases/download/4.5.5/opencv-4.5.5-android-sdk.zip',
'src' : 'https://github.com/opencv/opencv/releases/download/4.8.0/opencv-4.8.0-android-sdk.zip',
'dest': 'OpenCV'
]
]

View File

@ -76,6 +76,7 @@ dependencies {
api project(path: ':common')
api project(path: ':automator')
implementation 'com.rmtheis:tess-two:9.1.0'
implementation 'com.google.mlkit:text-recognition-chinese:16.0.0-beta6'
implementation 'com.google.mlkit:text-recognition-chinese:16.0.0'
implementation 'com.microsoft.onnxruntime:onnxruntime-android:1.15.1'
}

View File

@ -0,0 +1,306 @@
package com.stardust.autojs.onnx;
import android.os.Build;
import android.util.Log;
import com.google.gson.Gson;
import com.stardust.autojs.onnx.domain.DetectResult;
import com.stardust.autojs.onnx.domain.Detection;
import com.stardust.autojs.onnx.util.Letterbox;
import org.opencv.android.OpenCVLoader;
import org.opencv.core.Mat;
import org.opencv.core.Size;
import org.opencv.imgcodecs.Imgcodecs;
import org.opencv.imgproc.Imgproc;
import java.nio.FloatBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import ai.onnxruntime.providers.NNAPIFlags;
import androidx.annotation.RequiresApi;
/**
* @author TonyJiangWJ
* @since 2023/8/20
*/
@RequiresApi(api = Build.VERSION_CODES.N)
public class YoloV8Predictor {
private static final String TAG = "YoloV8Predictor";
static {
OpenCVLoader.initDebug();
}
private final String modelPath;
private float confThreshold = 0.35F;
private float nmsThreshold = 0.55F;
private boolean tryNpu;
private Size shapeSize = new Size(640, 640);
private Letterbox letterbox;
private List<String> labels = new ArrayList<>();
private List<String> apiFlags = Arrays.asList("CPU_DISABLED");
public YoloV8Predictor(String modelPath) {
this.modelPath = modelPath;
}
public YoloV8Predictor(String modelPath, float confThreshold, float nmsThreshold) {
this.modelPath = modelPath;
this.confThreshold = confThreshold;
this.nmsThreshold = nmsThreshold;
}
public void setConfThreshold(float confThreshold) {
this.confThreshold = confThreshold;
}
public void setNmsThreshold(float nmsThreshold) {
this.nmsThreshold = nmsThreshold;
}
public void setLabels(List<String> labels) {
this.labels = labels;
}
public void setShapeSize(double width, double height) {
this.shapeSize = new Size(width, height);
}
public void setTryNpu(boolean tryNpu) {
this.tryNpu = tryNpu;
}
public void setApiFlags(List<String> apiFlags) {
this.apiFlags = apiFlags;
}
private OrtSession session;
private OrtEnvironment environment;
private void prepareSession() throws OrtException {
if (environment != null) {
return;
}
// 加载ONNX模型
environment = OrtEnvironment.getEnvironment();
OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions();
addNNApiProvider(sessionOptions);
session = environment.createSession(modelPath, sessionOptions);
// 输出基本信息
session.getInputInfo().keySet().forEach(x -> {
try {
System.out.println("input name = " + x);
System.out.println(session.getInputInfo().get(x).getInfo().toString());
} catch (OrtException e) {
throw new RuntimeException(e);
}
});
}
private void addNNApiProvider(OrtSession.SessionOptions sessionOptions) {
if (!tryNpu) {
return;
}
try {
List<NNAPIFlags> flags = new ArrayList<>();
if (apiFlags.contains("USE_FP16")) {
flags.add(NNAPIFlags.USE_FP16);
}
if (apiFlags.contains("USE_NCHW")) {
flags.add(NNAPIFlags.USE_NCHW);
}
if (apiFlags.contains("CPU_ONLY")) {
flags.add(NNAPIFlags.CPU_ONLY);
}
if (apiFlags.contains("CPU_DISABLED")) {
flags.add(NNAPIFlags.CPU_DISABLED);
}
Log.d(TAG, "addNNApiProvider: 当前启用nnapiFlags:" + new Gson().toJson(apiFlags));
sessionOptions.addNnapi(EnumSet.copyOf(flags));
Log.d(TAG, "prepareSession: 启用nnapi成功");
} catch (Exception e) {
Log.e(TAG, "prepareSession: 无法启用nnapi");
}
}
private HashMap<String, OnnxTensor> preprocessImage(Mat img) throws OrtException {
// 读取 image
Mat image = img.clone();
// 将四通道转换为三通道
if (image.channels() == 4) {
Imgproc.cvtColor(image, image, Imgproc.COLOR_RGBA2BGR);
}
Log.d(TAG, "preprocessImage: image's channels: " + image.channels());
Imgproc.cvtColor(image, image, Imgproc.COLOR_BGR2RGB);
// 更改 image 尺寸
letterbox = new Letterbox();
letterbox.setNewShape(this.shapeSize);
image = letterbox.letterbox(image);
int rows = letterbox.getHeight();
int cols = letterbox.getWidth();
int channels = image.channels();
// 将Mat对象的像素值赋值给Float[]对象
float[] pixels = new float[channels * rows * cols];
for (int i = 0; i < rows; i++) {
for (int j = 0; j < cols; j++) {
double[] pixel = image.get(j, i);
for (int k = 0; k < channels; k++) {
// 这样设置相当于同时做了image.transpose((2, 0, 1))操作
pixels[rows * cols * k + j * cols + i] = (float) pixel[k] / 255.0f;
}
}
}
image.release();
// 创建OnnxTensor对象
long[] shape = {1L, (long) channels, (long) rows, (long) cols};
OnnxTensor tensor = OnnxTensor.createTensor(environment, FloatBuffer.wrap(pixels), shape);
HashMap<String, OnnxTensor> stringOnnxTensorHashMap = new HashMap<>();
stringOnnxTensorHashMap.put(session.getInputInfo().keySet().iterator().next(), tensor);
return stringOnnxTensorHashMap;
}
private List<Detection> postProcessOutput(OrtSession.Result output) throws OrtException {
float[][] outputData = ((float[][][]) output.get(0).getValue())[0];
outputData = transposeMatrix(outputData);
Map<Integer, List<float[]>> class2Bbox = new HashMap<>();
for (float[] bbox : outputData) {
float[] conditionalProbabilities = Arrays.copyOfRange(bbox, 4, outputData.length);
int label = argmax(conditionalProbabilities);
float conf = conditionalProbabilities[label];
if (conf < confThreshold) {
continue;
}
bbox[4] = conf;
// xywh to (x1, y1, x2, y2)
xywh2xyxy(bbox);
// skip invalid predictions
if (bbox[0] >= bbox[2] || bbox[1] >= bbox[3]) {
continue;
}
class2Bbox.putIfAbsent(label, new ArrayList<>());
class2Bbox.get(label).add(bbox);
}
List<Detection> detections = new ArrayList<>();
for (Map.Entry<Integer, List<float[]>> entry : class2Bbox.entrySet()) {
int label = entry.getKey();
List<float[]> bboxes = entry.getValue();
bboxes = nonMaxSuppression(bboxes, nmsThreshold);
for (float[] bbox : bboxes) {
String labelString = "";
if (labels.size() - 1 < label) {
labelString = String.valueOf(label);
} else {
labelString = labels.get(label);
}
detections.add(new Detection(labelString, entry.getKey(), Arrays.copyOfRange(bbox, 0, 4), bbox[4]));
}
}
return detections;
}
public List<DetectResult> predictYolo(String imagePath) throws OrtException {
return predictYolo(Imgcodecs.imread(imagePath));
}
public List<DetectResult> predictYolo(Mat image) throws OrtException {
prepareSession();
long start_time = System.currentTimeMillis();
// 运行推理
OrtSession.Result output = session.run(preprocessImage(image));
List<Detection> detections = postProcessOutput(output);
System.out.printf("time%d ms.\n", (System.currentTimeMillis() - start_time));
return detections.stream().map(detection -> new DetectResult(detection, letterbox))
.collect(Collectors.toList());
}
public static void xywh2xyxy(float[] bbox) {
float x = bbox[0];
float y = bbox[1];
float w = bbox[2];
float h = bbox[3];
bbox[0] = x - w * 0.5f;
bbox[1] = y - h * 0.5f;
bbox[2] = x + w * 0.5f;
bbox[3] = y + h * 0.5f;
}
public static float[][] transposeMatrix(float[][] m) {
float[][] temp = new float[m[0].length][m.length];
for (int i = 0; i < m.length; i++) {
for (int j = 0; j < m[0].length; j++) {
temp[j][i] = m[i][j];
}
}
return temp;
}
public static List<float[]> nonMaxSuppression(List<float[]> bboxes, float iouThreshold) {
List<float[]> bestBboxes = new ArrayList<>();
bboxes.sort(Comparator.comparing(a -> a[4]));
while (!bboxes.isEmpty()) {
float[] bestBbox = bboxes.remove(bboxes.size() - 1);
bestBboxes.add(bestBbox);
bboxes = bboxes.stream().filter(a -> computeIOU(a, bestBbox) < iouThreshold).collect(Collectors.toList());
}
return bestBboxes;
}
public static float computeIOU(float[] box1, float[] box2) {
float area1 = (box1[2] - box1[0]) * (box1[3] - box1[1]);
float area2 = (box2[2] - box2[0]) * (box2[3] - box2[1]);
float left = Math.max(box1[0], box2[0]);
float top = Math.max(box1[1], box2[1]);
float right = Math.min(box1[2], box2[2]);
float bottom = Math.min(box1[3], box2[3]);
float interArea = Math.max(right - left, 0) * Math.max(bottom - top, 0);
float unionArea = area1 + area2 - interArea;
return Math.max(interArea / unionArea, 1e-8f);
}
//返回最大值的索引
public static int argmax(float[] a) {
float re = -Float.MAX_VALUE;
int arg = -1;
for (int i = 0; i < a.length; i++) {
if (a[i] >= re) {
re = a[i];
arg = i;
}
}
return arg;
}
}

View File

@ -0,0 +1,89 @@
package com.stardust.autojs.onnx.domain;
import com.stardust.autojs.onnx.util.Letterbox;
/**
* @author TonyJiangWJ
* @since 2023/8/20
*/
public class DetectResult {
private String label;
private Integer clsId;
private double left;
private double top;
private double right;
private double bottom;
private float confidence;
public DetectResult() {
}
public DetectResult(Detection detection, Letterbox letterbox) {
this.label = detection.label;
this.confidence = detection.confidence;
double dw = letterbox.getDw();
double dh = letterbox.getDh();
double ratio = letterbox.getRatio();
left = (detection.getBbox()[0] - dw) / ratio;
right = (detection.getBbox()[2] - dw) / ratio;
top = (detection.getBbox()[1] - dh) / ratio;
bottom = (detection.getBbox()[3] - dh) / ratio;
}
public String getLabel() {
return label;
}
public void setLabel(String label) {
this.label = label;
}
public Integer getClsId() {
return clsId;
}
public void setClsId(Integer clsId) {
this.clsId = clsId;
}
public double getLeft() {
return left;
}
public void setLeft(double left) {
this.left = left;
}
public double getTop() {
return top;
}
public void setTop(double top) {
this.top = top;
}
public double getRight() {
return right;
}
public void setRight(double right) {
this.right = right;
}
public double getBottom() {
return bottom;
}
public void setBottom(double bottom) {
this.bottom = bottom;
}
public float getConfidence() {
return confidence;
}
public void setConfidence(float confidence) {
this.confidence = confidence;
}
}

View File

@ -0,0 +1,62 @@
package com.stardust.autojs.onnx.domain;
/**
* @author TonyJiangWJ
* @since 2023/8/20
*/
public class Detection {
public String label;
private Integer clsId;
public float[] bbox;
public float confidence;
public Detection(String label,Integer clsId, float[] bbox, float confidence){
this.clsId = clsId;
this.label = label;
this.bbox = bbox;
this.confidence = confidence;
}
public Detection(){
}
public Integer getClsId() {
return clsId;
}
public void setClsId(Integer clsId) {
this.clsId = clsId;
}
public String getLabel() {
return label;
}
public void setLabel(String label) {
this.label = label;
}
public float[] getBbox() {
return bbox;
}
public void setBbox(float[] bbox) {
}
@Override
public String toString() {
return " label="+label +
" \t clsId="+clsId +
" \t x0="+bbox[0] +
" \t y0="+bbox[1] +
" \t x1="+bbox[2] +
" \t y1="+bbox[3] +
" \t score="+confidence;
}
}

View File

@ -0,0 +1,87 @@
package com.stardust.autojs.onnx.util;
import org.opencv.core.Core;
import org.opencv.core.Mat;
import org.opencv.core.Size;
import org.opencv.imgproc.Imgproc;
/**
* @author TonyJiangWJ
* @since 2023/8/20
*/
public class Letterbox {
private Size newShape = new Size(640, 640);
private final double[] color = new double[]{114, 114, 114};
private final Boolean auto = false;
private final Boolean scaleUp = true;
private Integer stride = 32;
private double ratio;
private double dw;
private double dh;
public double getRatio() {
return ratio;
}
public double getDw() {
return dw;
}
public Integer getWidth() {
return (int) this.newShape.width;
}
public Integer getHeight() {
return (int) this.newShape.height;
}
public double getDh() {
return dh;
}
public void setNewShape(Size newShape) {
this.newShape = newShape;
}
public void setStride(Integer stride) {
this.stride = stride;
}
public Mat letterbox(Mat im) { // 调整图像大小和填充图像使满足步长约束并记录参数
// 当前形状 [height, width]
int[] shape = {im.rows(), im.cols()};
// Scale ratio (new / old)
double r = Math.min(this.newShape.height / shape[0], this.newShape.width / shape[1]);
// 仅缩小不扩大一且为了mAP
if (!this.scaleUp) {
r = Math.min(r, 1.0);
}
// Compute padding
Size newUnpad = new Size(Math.round(shape[1] * r), Math.round(shape[0] * r));
// wh 填充
double dw = this.newShape.width - newUnpad.width, dh = this.newShape.height - newUnpad.height;
// 最小矩形
if (this.auto) {
dw = dw % this.stride;
dh = dh % this.stride;
}
// 填充的时候两边都填充一半使图像居于中心
dw /= 2;
dh /= 2;
// resize
if (shape[1] != newUnpad.width || shape[0] != newUnpad.height) {
Imgproc.resize(im, im, newUnpad, 0, 0, Imgproc.INTER_LINEAR);
}
int top = (int) Math.round(dh - 0.1), bottom = (int) Math.round(dh + 0.1);
int left = (int) Math.round(dw - 0.1), right = (int) Math.round(dw + 0.1);
// 将图像填充为正方形
Core.copyMakeBorder(im, im, top, bottom, left, right, Core.BORDER_CONSTANT, new org.opencv.core.Scalar(this.color));
this.ratio = r;
this.dh = dh;
this.dw = dw;
return im;
}
}

View File

@ -15,6 +15,7 @@ android {
signingConfigs {
if (file(homePath + '/auto-js-t-pkcs12.jks').exists()) {
signSupport = true
release {
storeFile file(homePath + '/auto-js-t-pkcs12.jks')
storePassword storePasswd