YOLO相关代码使用抽象工厂重构提高后续扩展和可维护性

This commit is contained in:
TonyJiangWJ 2025-01-05 22:32:51 +08:00
parent 020258392f
commit 5aa55cda3e
28 changed files with 893 additions and 942 deletions

View File

@ -1,5 +1,8 @@
NCNN和PaddleOCR同时使用时有兼容性问题导致无法正常运行甚至闪退请勿同时使用
NCNN和PaddleOCR同时使用时,
有兼容性问题导致无法正常运行甚至闪退,请勿同时使用
该问题暂时无法解决因此如果需要使用PaddleOcr时 请使用onnx或者使用mlkitocr进行文字识别
该问题暂时无法解决因此如果需要使用PaddleOcr时
请使用onnx或者使用mlkitocr进行文字识别
建议在依赖性能的情况下使用ncnn同时使用mlkitocr进行文字识别
建议在依赖性能的情况下使用ncnn
同时使用mlkitocr进行文字识别

View File

@ -1,6 +1,6 @@
const img = images.read("./test.png")
console.show()
setTimeout(() -> console.hide(), 15000)
setTimeout(() => console.hide(), 15000)
let cpuThreadNum = 4
// PaddleOCR 移动端提供了两种模型ocr_v3_for_cpu与ocr_v3_for_cpu(slim),此选项用于选择加载的模型,默认true使用v3的slim版(速度更快)false使用v3的普通版(准确率更高)
let useSlim = true
@ -14,8 +14,8 @@ let result = $ocr.detect(img, { cpuThreadNum, useSlim })
img.recycle()
log('slim识别耗时' + (new Date() - start) + 'ms')
let model_path = '/sdcard/脚本/best.bin'
let param_path = '/sdcard/脚本/best.param'
let model_path = '/sdcard/脚本/manor.bin'
let param_path = '/sdcard/脚本/manor.param'
if (!files.exists(model_path) || !files.exists(param_path)) {
toastLog('请确认已下载了模型文件')
exit()

View File

@ -1,8 +1,7 @@
console.show()
setTimeout(() -> console.hide(), 15000)
setTimeout(() -> console.hide(), 15000)
let model_path = '/sdcard/脚本/best.bin'
let param_path = '/sdcard/脚本/best.param'
setTimeout(() => console.hide(), 15000)
let model_path = '/sdcard/脚本/manor.bin'
let param_path = '/sdcard/脚本/manor.param'
if (!files.exists(model_path) || !files.exists(param_path)) {
toastLog('请确认已下载了模型文件')
exit()

Binary file not shown.

Before

Width:  |  Height:  |  Size: 134 KiB

View File

@ -1,39 +0,0 @@
let model_path = '/sdcard/脚本/yolov8n.bin'
let param_path = '/sdcard/脚本/yolov8n.param'
if (!files.exists(model_path) || !files.exists(param_path)) {
toastLog('请确认已下载了模型文件')
exit()
}
console.show()
setTimeout(() -> console.hide(), 15000)
let yoloInit = $yolo.init({
type: 'ncnn',
useGpu: true,
paramPath: files.path(param_path),
binPath: files.path(model_path),
imageSize: 480,
labels: [
"person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light",
"fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow",
"elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee",
"skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard",
"tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple",
"sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch",
"potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone",
"microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear",
"hair drier", "toothbrush"
]
})
if (!yoloInit) {
toast('初始化失败')
exit()
}
const img = images.read("./bus.jpg")
let start = new Date()
const result = $yolo.forward(img)
toastLog('ncnn cost: ' + (new Date() - start) + 'ms')
log('predict result:' + JSON.stringify(result, null, 4))
img.recycle()

View File

@ -1,38 +0,0 @@
let model_path = '/sdcard/脚本/yolov8n.bin'
let param_path = '/sdcard/脚本/yolov8n.param'
if (!files.exists(model_path) || !files.exists(param_path)) {
toastLog('请确认已下载了模型文件')
exit()
}
console.show()
setTimeout(() -> console.hide(), 15000)
let yoloInit = $yolo.init({
type: 'ncnn',
paramPath: files.path(param_path),
binPath: files.path(model_path),
imageSize: 480,
labels: [
"person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light",
"fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow",
"elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee",
"skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard",
"tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple",
"sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch",
"potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone",
"microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear",
"hair drier", "toothbrush"
]
})
if (!yoloInit) {
toast('初始化失败')
exit()
}
const img = images.read("./bus.jpg")
let start = new Date()
const result = $yolo.forward(img)
toastLog('ncnn cost: ' + (new Date() - start) + 'ms')
log('predict result:' + JSON.stringify(result, null, 4))
img.recycle()

View File

@ -1,11 +1,11 @@
let model_path = '/sdcard/脚本/best.bin'
let param_path = '/sdcard/脚本/best.param'
let model_path = '/sdcard/脚本/manor.bin'
let param_path = '/sdcard/脚本/manor.param'
if (!files.exists(model_path) || !files.exists(param_path)) {
toastLog('请确认已下载了模型文件')
exit()
}
console.show()
setTimeout(() -> console.hide(), 15000)
setTimeout(() => console.hide(), 15000)
let yoloInit = $yolo.init({
type: 'ncnn',
paramPath: files.path(param_path),

View File

@ -4,7 +4,7 @@ if (!files.exists(model_path)) {
exit()
}
console.show()
setTimeout(() -> console.hide(), 15000)
setTimeout(() => console.hide(), 15000)
let yoloInit = $yolo.init({
type: 'onnx',
modelPath: files.path(model_path),

View File

@ -1,307 +0,0 @@
// 杀死当前同名脚本 see AutoScriptBase/lib/killMyDuplicator
(() => { let g = engines.myEngine(); var e = engines.all(), n = e.length; let r = g.getSource() + ""; 1 < n && e.forEach(e => { var n = e.getSource() + ""; g.id !== e.id && n == r && e.forceStop() }) })();
if (!requestScreenCapture()) {
toastLog('请求截图权限失败')
exit()
}
let onnxInstance = null
let ncnnInstance = null
let currentType = 'ncnn'
initYoloInstances()
let yoloInstance = {
ncnn: ncnnInstance,
onnx: onnxInstance,
}
// 识别结果和截图信息
let result = []
let img = null
let running = true
let capturing = false
/**
* 截图并识别OCR文本信息
*/
function captureAndDetect () {
capturing = true
img = captureScreen()
if (!img) {
toastLog('截图失败')
}
let start = new Date()
result = yoloInstance[currentType].forward(img)
console.verbose('识别结果:' + JSON.stringify(result))
toastLog('耗时' + (new Date() - start) + 'ms')
img && img.recycle()
capturing = false
}
// captureAndDetect()
// 获取状态栏高度
let offset = -getStatusBarHeightCompat()
// 绘制识别结果
let window = floaty.rawWindow(
<canvas id="canvas" layout_weight="1" />
);
// 设置悬浮窗位置
ui.post(() => {
window.setPosition(0, offset)
window.setSize(device.width, device.height)
window.setTouchable(false)
})
// 操作按钮
let clickButtonWindow = floaty.rawWindow(
<vertical>
<button id="changeYoloType" text="当前ncnn" />
<button id="captureAndDetect" text="截图识别" />
<button id="closeBtn" text="退出" />
</vertical>
);
ui.run(function () {
clickButtonWindow.setPosition(device.width / 2 - ~~(clickButtonWindow.getWidth() / 2), device.height * 0.65)
})
// 切换类型
clickButtonWindow.changeYoloType.click(function () {
threads.start(function () {
changeYoloType()
ui.run(function () {
if (currentType === 'onnx') {
clickButtonWindow.changeYoloType.setText('当前onnx')
} else {
clickButtonWindow.changeYoloType.setText('当前ncnn')
}
})
})
})
// 点击识别
clickButtonWindow.captureAndDetect.click(function () {
if (capturing) {
return
}
result = []
let oldPosition = {
x: clickButtonWindow.getX(),
y: clickButtonWindow.getY(),
}
ui.run(function () {
clickButtonWindow.setPosition(device.width, device.height)
})
setTimeout(() => {
captureAndDetect()
ui.run(function () {
clickButtonWindow.setPosition(oldPosition.x, oldPosition.y)
})
}, 500)
})
// 点击关闭
clickButtonWindow.closeBtn.setOnTouchListener(new TouchController(clickButtonWindow, () => {
exit()
}).createListener())
let Typeface = android.graphics.Typeface
let paint = new Paint()
paint.setStrokeWidth(1)
paint.setTypeface(Typeface.DEFAULT_BOLD)
paint.setTextAlign(Paint.Align.LEFT)
paint.setAntiAlias(true)
paint.setStrokeJoin(Paint.Join.ROUND)
paint.setDither(true)
window.canvas.on('draw', function (canvas) {
if (!running || capturing) {
return
}
// 清空内容
canvas.drawColor(0xFFFFFF, android.graphics.PorterDuff.Mode.CLEAR)
if (result && result.length > 0) {
for (let i = 0; i < result.length; i++) {
let detectResult = result[i]
drawRectAndText(detectResult.label, detectResult.bounds, '#00ff00', canvas, paint)
}
}
})
setInterval(() => { }, 10000)
events.on('exit', () => {
// 标记停止 避免canvas导致闪退
running = false
// 撤销监听
window.canvas.removeAllListeners()
// 回收图片
img && img.recycle()
})
/**
* 绘制文本和方框
*
* @param {*} desc
* @param {*} rect
* @param {*} colorStr
* @param {*} canvas
* @param {*} paint
*/
function drawRectAndText (desc, rect, colorStr, canvas, paint) {
let color = colors.parseColor(colorStr)
paint.setStrokeWidth(1)
paint.setStyle(Paint.Style.STROKE)
// 反色
paint.setARGB(255, 255 - (color >> 16 & 0xff), 255 - (color >> 8 & 0xff), 255 - (color & 0xff))
canvas.drawRect(rect, paint)
paint.setARGB(255, color >> 16 & 0xff, color >> 8 & 0xff, color & 0xff)
paint.setStrokeWidth(1)
paint.setTextSize(20)
paint.setStyle(Paint.Style.FILL)
canvas.drawText(desc, rect.left, rect.top, paint)
paint.setTextSize(10)
paint.setStrokeWidth(1)
paint.setARGB(255, 0, 0, 0)
}
/**
* 获取状态栏高度
*
* @returns
*/
function getStatusBarHeightCompat () {
let result = 0
let resId = context.getResources().getIdentifier("status_bar_height", "dimen", "android")
if (resId > 0) {
result = context.getResources().getDimensionPixelOffset(resId)
}
if (result <= 0) {
result = context.getResources().getDimensionPixelOffset(R.dimen.dimen_25dp)
}
return result
}
function initYoloInstances () {
let onnx_model_path = '/sdcard/脚本/manor_lite.onnx'
if (!files.exists(onnx_model_path)) {
toastLog('请确认已下载了onnx模型文件')
exit()
}
let model_path = '/sdcard/脚本/best.bin'
let param_path = '/sdcard/脚本/best.param'
if (!files.exists(model_path) || !files.exists(param_path)) {
toastLog('请确认已下载了模型文件')
exit()
}
let ncnnInit = $yolo.init({
type: 'ncnn',
paramPath: files.path(param_path),
binPath: files.path(model_path),
imageSize: 480,
labels: [
'booth_btn', 'collect_coin', 'collect_egg', 'collect_food', 'cook', 'countdown', 'donate',
'eating_chicken', 'employ', 'empty_booth', 'feed_btn', 'friend_btn', 'has_food', 'has_shit',
'hungry_chicken', 'item', 'kick-out', 'no_food', 'not_ready', 'operation_booth', 'plz-go',
'punish_booth', 'punish_btn', 'signboard', 'sleep', 'speedup', 'sports', 'stopped_booth',
'thief_chicken', 'close_btn', 'collect_muck', 'confirm_btn', 'working_chicken',
]
})
if (ncnnInit) {
ncnnInstance = $yolo.getInstance()
} else {
toastLog('ncnn初始化失败')
}
let onnxInit = $yolo.init({
type: 'onnx',
modelPath: files.path(onnx_model_path),
imageSize: 480,
labels: [
'booth_btn', 'collect_coin', 'collect_egg', 'collect_food', 'cook', 'countdown', 'donate',
'eating_chicken', 'employ', 'empty_booth', 'feed_btn', 'friend_btn', 'has_food', 'has_shit',
'hungry_chicken', 'item', 'kick-out', 'no_food', 'not_ready', 'operation_booth', 'plz-go',
'punish_booth', 'punish_btn', 'signboard', 'sleep', 'speedup', 'sports', 'stopped_booth',
'thief_chicken', 'close_btn', 'collect_muck', 'confirm_btn', 'working_chicken', 'bring_back',
'leave_msg', 'speedup_eating',
]
})
if (onnxInit) {
onnxInstance = $yolo.getInstance()
} else {
toastLog('onnx初始化失败')
}
}
function changeYoloType () {
let options = ["ncnn", "onnx"]
let idx = dialogs.singleChoice("请选择YOLO推理类型", options, options.indexOf(currentType))
let targetType = options[idx]
toast("选择了: " + targetType)
if (!yoloInstance[targetType]) {
toastLog('目标类型未能初始化:' + targetType)
return
}
currentType = targetType
}
function TouchController (buttonWindow, handleClick, handleDown, handleUp) {
this.eventStartX = null
this.eventStartY = null
this.windowStartX = buttonWindow.getX()
this.windowStartY = buttonWindow.getY()
this.eventKeep = false
this.eventMoving = false
this.touchDownTime = new Date().getTime()
this.createListener = function () {
let _this = this
return new android.view.View.OnTouchListener((view, event) => {
try {
switch (event.getAction()) {
case event.ACTION_DOWN:
handleDown && handleDown()
_this.eventStartX = event.getRawX();
_this.eventStartY = event.getRawY();
_this.windowStartX = buttonWindow.getX();
_this.windowStartY = buttonWindow.getY();
_this.eventKeep = true; //按下,开启计时
_this.touchDownTime = new Date().getTime()
break;
case event.ACTION_MOVE:
var sx = event.getRawX() - _this.eventStartX;
var sy = event.getRawY() - _this.eventStartY;
if (!_this.eventMoving && _this.eventKeep && getDistance(sx, sy) >= 10) {
_this.eventMoving = true;
}
if (_this.eventMoving && _this.eventKeep) {
ui.post(() => {
buttonWindow.setPosition(_this.windowStartX + sx, _this.windowStartY + sy);
})
}
break;
case event.ACTION_UP:
handleUp && handleUp()
if (!_this.eventMoving && _this.eventKeep && _this.touchDownTime > new Date().getTime() - 1000) {
handleClick && handleClick()
}
_this.eventKeep = false;
_this.touchDownTime = 0;
_this.eventMoving = false;
break;
}
} catch (e) {
console.error('异常' + e)
}
return true;
})
}
}
function getDistance (dx, dy) {
return Math.sqrt(Math.pow(dx, 2) + Math.pow(dy, 2));
}

View File

@ -7,34 +7,36 @@ if (!requestScreenCapture()) {
}
let onnxInstance = null
let ncnnInstance = null
let currentType = 'ncnn'
let initSuccess = false
initYoloInstances()
let yoloInstance = {
ncnn: ncnnInstance,
onnx: onnxInstance,
}
// 识别结果和截图信息
let result = []
let img = null
let running = true
let capturing = false
let cost = 0
/**
* 截图并识别OCR文本信息
*/
function captureAndDetect () {
if (!initSuccess) {
toastLog('当前推理模型未能初始化,请选择另一个')
return
}
capturing = true
img = captureScreen()
if (!img) {
toastLog('截图失败')
}
let start = new Date()
result = $yolo.forward(img)
result = yoloInstance[currentType].forward(img)
console.verbose('识别结果:' + JSON.stringify(result))
toastLog('耗时' + (new Date() - start) + 'ms')
cost = (new Date() - start)
toastLog('耗时' + cost + 'ms')
img && img.recycle()
capturing = false
}
@ -127,6 +129,10 @@ window.canvas.on('draw', function (canvas) {
drawRectAndText(detectResult.label, detectResult.bounds, '#00ff00', canvas, paint)
}
}
drawText('请打开支付宝蚂蚁庄园界面进行识别', 100, device.height - 300, '#00ff00', canvas, paint)
if (cost > 0) {
drawText('识别耗时:' + cost + 'ms', 100, device.height - 250, '#00ff00', canvas, paint)
}
})
setInterval(() => { }, 10000)
@ -156,16 +162,42 @@ function drawRectAndText (desc, rect, colorStr, canvas, paint) {
// 反色
paint.setARGB(255, 255 - (color >> 16 & 0xff), 255 - (color >> 8 & 0xff), 255 - (color & 0xff))
canvas.drawRect(rect, paint)
paint.setARGB(255, color >> 16 & 0xff, color >> 8 & 0xff, color & 0xff)
paint.setStrokeWidth(1)
paint.setTextSize(20)
paint.setStyle(Paint.Style.FILL)
canvas.drawText(desc, rect.left + 1, rect.top + 2, paint)
paint.setARGB(255, color >> 16 & 0xff, color >> 8 & 0xff, color & 0xff)
canvas.drawText(desc, rect.left, rect.top, paint)
paint.setTextSize(10)
paint.setStrokeWidth(1)
paint.setARGB(255, 0, 0, 0)
}
/**
* 绘制文本
*
* @param {*} desc
* @param {*} left
* @param {*} top
* @param {*} colorStr
* @param {*} canvas
* @param {*} paint
*/
function drawText (desc, left, top, colorStr, canvas, paint) {
let color = colors.parseColor(colorStr)
paint.setStrokeWidth(1)
paint.setStyle(Paint.Style.STROKE)
paint.setStrokeWidth(1)
paint.setTextSize(30)
paint.setStyle(Paint.Style.FILL)
// 反色 阴影
paint.setARGB(255, 255 - (color >> 16 & 0xff), 255 - (color >> 8 & 0xff), 255 - (color & 0xff))
canvas.drawText(desc, left + 1, top + 2, paint)
paint.setARGB(255, color >> 16 & 0xff, color >> 8 & 0xff, color & 0xff)
canvas.drawText(desc, left, top, paint)
}
/**
* 获取状态栏高度
*
@ -185,61 +217,55 @@ function getStatusBarHeightCompat () {
function initYoloInstances () {
if (initSuccess) {
$yolo.release()
let onnx_model_path = '/sdcard/脚本/manor_lite.onnx'
if (!files.exists(onnx_model_path)) {
toastLog('请确认已下载了onnx模型文件')
exit()
}
initSuccess = false
if (currentType == 'ncnn') {
let model_path = '/sdcard/脚本/best.bin'
let param_path = '/sdcard/脚本/best.param'
if (!files.exists(model_path) || !files.exists(param_path)) {
toastLog('请确认已下载了模型文件')
return
}
let ncnnInit = $yolo.init({
type: 'ncnn',
paramPath: files.path(param_path),
binPath: files.path(model_path),
imageSize: 480,
labels: [
'booth_btn', 'collect_coin', 'collect_egg', 'collect_food', 'cook', 'countdown', 'donate',
'eating_chicken', 'employ', 'empty_booth', 'feed_btn', 'friend_btn', 'has_food', 'has_shit',
'hungry_chicken', 'item', 'kick-out', 'no_food', 'not_ready', 'operation_booth', 'plz-go',
'punish_booth', 'punish_btn', 'signboard', 'sleep', 'speedup', 'sports', 'stopped_booth',
'thief_chicken', 'close_btn', 'collect_muck', 'confirm_btn', 'working_chicken',
]
})
if (ncnnInit) {
initSuccess = true
} else {
toastLog('ncnn初始化失败')
}
let model_path = '/sdcard/脚本/manor.bin'
let param_path = '/sdcard/脚本/manor.param'
if (!files.exists(model_path) || !files.exists(param_path)) {
toastLog('请确认已下载了ncnn模型文件')
exit()
}
let ncnnInit = $yolo.init({
type: 'ncnn',
paramPath: files.path(param_path),
binPath: files.path(model_path),
imageSize: 480,
// ncnn 版本必须填写labels
labels: [
'booth_btn', 'collect_coin', 'collect_egg', 'collect_food', 'cook', 'countdown', 'donate',
'eating_chicken', 'employ', 'empty_booth', 'feed_btn', 'friend_btn', 'has_food', 'has_shit',
'hungry_chicken', 'item', 'kick-out', 'no_food', 'not_ready', 'operation_booth', 'plz-go',
'punish_booth', 'punish_btn', 'signboard', 'sleep', 'speedup', 'sports', 'stopped_booth',
'thief_chicken', 'close_btn', 'collect_muck', 'confirm_btn', 'working_chicken',
]
})
if (ncnnInit) {
ncnnInstance = $yolo.getInstance()
} else {
let onnx_model_path = '/sdcard/脚本/manor_lite.onnx'
if (!files.exists(onnx_model_path)) {
toastLog('请确认已下载了onnx模型文件')
return
}
let onnxInit = $yolo.init({
type: 'onnx',
modelPath: files.path(onnx_model_path),
imageSize: 480,
labels: [
'booth_btn', 'collect_coin', 'collect_egg', 'collect_food', 'cook', 'countdown', 'donate',
'eating_chicken', 'employ', 'empty_booth', 'feed_btn', 'friend_btn', 'has_food', 'has_shit',
'hungry_chicken', 'item', 'kick-out', 'no_food', 'not_ready', 'operation_booth', 'plz-go',
'punish_booth', 'punish_btn', 'signboard', 'sleep', 'speedup', 'sports', 'stopped_booth',
'thief_chicken', 'close_btn', 'collect_muck', 'confirm_btn', 'working_chicken', 'bring_back',
'leave_msg', 'speedup_eating',
]
})
if (onnxInit) {
initSuccess = true
} else {
toastLog('onnx初始化失败')
}
toastLog('ncnn初始化失败')
}
let onnxInit = $yolo.init({
type: 'onnx',
modelPath: files.path(onnx_model_path),
imageSize: 480,
// onnx版本可以不填写labels可以通过onnx模型自动提取当然也可以自己提供比如映射成中文等
labels: [
'摆摊按钮', '收集金币', '收蛋', '领饲料', '去做饭', '倒计时', '捐蛋',
'eating_chicken', 'employ', '空摊位', 'feed_btn', 'friend_btn', 'has_food', 'has_shit',
'hungry_chicken', '道具', 'kick-out', 'no_food', 'not_ready', 'operation_booth', 'plz-go',
'punish_booth', 'punish_btn', 'signboard', 'sleep', 'speedup', 'sports', 'stopped_booth',
'thief_chicken', 'close_btn', 'collect_muck', 'confirm_btn', 'working_chicken',
]
})
if (onnxInit) {
onnxInstance = $yolo.getInstance()
} else {
toastLog('onnx初始化失败')
}
}
@ -249,8 +275,11 @@ function changeYoloType () {
let idx = dialogs.singleChoice("请选择YOLO推理类型", options, options.indexOf(currentType))
let targetType = options[idx]
toast("选择了: " + targetType)
if (!yoloInstance[targetType]) {
toastLog('目标类型未能初始化:' + targetType)
return
}
currentType = targetType
initYoloInstances()
}
function TouchController (buttonWindow, handleClick, handleDown, handleUp) {

View File

@ -1,9 +1,7 @@
存放路径 下载地址
存放路径 下载地址https://pan.quark.cn/s/7242eae30941
ncnn:
/sdcard/脚本/yolov8n.param
/sdcard/脚本/yolov8n.bin
/sdcard/脚本/best.param
/sdcard/脚本/best.bin
/sdcard/脚本/manor.param
/sdcard/脚本/manor.bin
onnx:
/sdcard/脚本/manor_lite.onnx

View File

@ -25,3 +25,5 @@
#-renamesourcefileattribute SourceFile
#-keep public class com.stardust.autojs.onnx.YoloV8Predictor
#-keepnames class com.stardust.autojs.onnx.YoloV8Predictor
#-keep public class com.stardust.autojs.yolo.onnx.OnnxYoloV8Predictor
#-keepnames class com.stardust.autojs.yolo.onnx.OnnxYoloV8Predictor

View File

@ -103,7 +103,6 @@ public class Mat extends org.opencv.core.Mat implements ResourceMonitor.Resource
}
mReleased = true;
}
super.finalize();
}
@Override

View File

@ -1,374 +1,25 @@
package com.stardust.autojs.onnx;
import android.os.Build;
import android.util.Log;
import com.google.gson.Gson;
import com.stardust.autojs.onnx.domain.DetectResult;
import com.stardust.autojs.onnx.domain.Detection;
import com.stardust.autojs.onnx.util.Letterbox;
import com.stardust.autojs.runtime.api.YoloPredictor;
import com.stardust.autojs.yolo.onnx.OnnxYoloV8Predictor;
import org.opencv.core.CvType;
import org.opencv.core.Mat;
import org.opencv.core.Size;
import org.opencv.imgcodecs.Imgcodecs;
import org.opencv.imgproc.Imgproc;
import java.nio.FloatBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import ai.onnxruntime.providers.NNAPIFlags;
import androidx.annotation.RequiresApi;
/**
* @author TonyJiangWJ
* @since 2023/8/20
* transfer from https://gitee.com/agricultureiot/yolo-onnx-java
* 适配旧版本脚本
*/
@RequiresApi(api = Build.VERSION_CODES.N)
public class YoloV8Predictor extends YoloPredictor {
private static final String TAG = "YoloV8Predictor";
private static final Pattern IMG_SIZE_PATTERN = Pattern.compile("\\[(\\d+), \\d+]");
private static final Pattern LABEL_PATTERN = Pattern.compile("'([^']*)'");
private final String modelPath;
private boolean tryNpu;
private Size shapeSize = new Size(640, 640);
private Letterbox letterbox;
private List<String> apiFlags = Arrays.asList("CPU_DISABLED");
public class YoloV8Predictor extends OnnxYoloV8Predictor {
public YoloV8Predictor(String modelPath) {
this.modelPath = modelPath;
init = true;
super(modelPath);
}
public YoloV8Predictor(String modelPath, float confThreshold, float nmsThreshold) {
this.modelPath = modelPath;
this.confThreshold = confThreshold;
this.nmsThreshold = nmsThreshold;
init = true;
super(modelPath, confThreshold, nmsThreshold);
}
public void setShapeSize(double width, double height) {
this.shapeSize = new Size(width, height);
}
public void setTryNpu(boolean tryNpu) {
this.tryNpu = tryNpu;
}
public void setApiFlags(List<String> apiFlags) {
this.apiFlags = apiFlags;
}
private OrtSession session;
private OrtEnvironment environment;
private void prepareSession() throws OrtException {
if (environment != null) {
return;
}
// 加载ONNX模型
environment = OrtEnvironment.getEnvironment();
OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions();
addNNApiProvider(sessionOptions);
session = environment.createSession(modelPath, sessionOptions);
// 输出基本信息
session.getInputInfo().keySet().forEach(x -> {
try {
System.out.println("input name = " + x);
System.out.println(session.getInputInfo().get(x).getInfo().toString());
} catch (OrtException e) {
throw new RuntimeException(e);
}
});
// 如果入参labels无效或未定义使用模型内置labels
if (labels == null || labels.size() == 0) {
labels = initLabels(session);
}
initShapeSize(session);
}
private List<String> initLabels(OrtSession session) throws OrtException {
String meteStr = session.getMetadata().getCustomMetadata().get("names");
if (meteStr == null) {
Log.d(TAG, "initLabels: 读取names失败 无法自动修正labels");
return Collections.emptyList();
}
String[] labels = new String[meteStr.split(",").length];
Matcher matcher = LABEL_PATTERN.matcher(meteStr);
int h = 0;
while (matcher.find()) {
labels[h] = matcher.group(1);
h++;
}
return Arrays.asList(labels);
}
private void initShapeSize(OrtSession session) throws OrtException {
String meteStr = session.getMetadata().getCustomMetadata().get("imgsz");
Log.d(TAG, "initShapeSize: " + meteStr);
if (meteStr == null) {
Log.d(TAG, "initShapeSize: 读取imgsz失败 无法自动修正输入大小");
return;
}
Matcher matcher = IMG_SIZE_PATTERN.matcher(meteStr);
if (matcher.find()) {
String shapeSize = matcher.group(1);
if (shapeSize == null) {
Log.d(TAG, "initShapeSize: 读取imgsz格式异常 无法自动修正输入大小");
return;
}
this.shapeSize = new Size(Double.parseDouble(shapeSize), Double.parseDouble(shapeSize));
Log.d(TAG, "set shape size: " + shapeSize);
} else {
Log.d(TAG, "initShapeSize: 读取imgsz格式异常 无法自动修正输入大小");
}
}
private void addNNApiProvider(OrtSession.SessionOptions sessionOptions) {
if (!tryNpu) {
return;
}
try {
List<NNAPIFlags> flags = new ArrayList<>();
if (apiFlags.contains("USE_FP16")) {
flags.add(NNAPIFlags.USE_FP16);
}
if (apiFlags.contains("USE_NCHW")) {
flags.add(NNAPIFlags.USE_NCHW);
}
if (apiFlags.contains("CPU_ONLY")) {
flags.add(NNAPIFlags.CPU_ONLY);
}
if (apiFlags.contains("CPU_DISABLED")) {
flags.add(NNAPIFlags.CPU_DISABLED);
}
Log.d(TAG, "addNNApiProvider: 当前启用nnapiFlags:" + new Gson().toJson(apiFlags));
sessionOptions.addNnapi(EnumSet.copyOf(flags));
Log.d(TAG, "prepareSession: 启用nnapi成功");
} catch (Exception e) {
Log.e(TAG, "prepareSession: 无法启用nnapi");
}
}
private HashMap<String, OnnxTensor> preprocessImage(Mat img) throws OrtException {
// 读取 image
Mat image = img.clone();
// 将四通道转换为三通道
if (image.channels() == 4) {
Imgproc.cvtColor(image, image, Imgproc.COLOR_RGBA2RGB);
}
Log.d(TAG, "preprocessImage: image's channels: " + image.channels());
// 更改 image 尺寸
letterbox = new Letterbox();
letterbox.setNewShape(this.shapeSize);
image = letterbox.letterbox(image);
int rows = letterbox.getHeight();
int cols = letterbox.getWidth();
int channels = image.channels();
// 转换Mat对象的数据类型为CV_64F即64位浮点型
Mat convertedImage = new Mat();
image.convertTo(convertedImage, CvType.CV_64F);
// 获取整个像素数据
double[] pixelData = new double[rows * cols * channels];
convertedImage.get(0, 0, pixelData);
float[] pixels = new float[channels * rows * cols];
for (int i = 0; i < rows; i++) {
for (int j = 0; j < cols; j++) {
for (int k = 0; k < channels; k++) {
// 这样设置相当于同时做了image.transpose((2, 0, 1))操作
// 重新组织内存访问模式提高缓存效率
pixels[k * rows * cols + i * cols + j] = (float) (pixelData[(i * cols + j) * channels + k] / 255.0);
}
}
}
image.release();
convertedImage.release();
// 创建OnnxTensor对象
long[] shape = {1L, (long) channels, (long) rows, (long) cols};
OnnxTensor tensor = OnnxTensor.createTensor(environment, FloatBuffer.wrap(pixels), shape);
HashMap<String, OnnxTensor> stringOnnxTensorHashMap = new HashMap<>();
stringOnnxTensorHashMap.put(session.getInputInfo().keySet().iterator().next(), tensor);
return stringOnnxTensorHashMap;
}
private List<Detection> postProcessOutput(OrtSession.Result output) throws OrtException {
float[][] outputData = ((float[][][]) output.get(0).getValue())[0];
outputData = transposeMatrix(outputData);
Map<Integer, List<float[]>> class2Bbox = new HashMap<>();
for (float[] bbox : outputData) {
int label = argmax(bbox, 4); // 直接在原数组上进行操作
float conf = bbox[label + 4];
if (conf < confThreshold) {
continue;
}
bbox[4] = conf;
// xywh to (x1, y1, x2, y2)
xywh2xyxy(bbox);
// skip invalid predictions
if (bbox[0] >= bbox[2] || bbox[1] >= bbox[3]) {
continue;
}
class2Bbox.computeIfAbsent(label, k -> new ArrayList<>()).add(bbox);
}
List<Detection> detections = new ArrayList<>();
for (Map.Entry<Integer, List<float[]>> entry : class2Bbox.entrySet()) {
int label = entry.getKey();
List<float[]> bboxes = entry.getValue();
bboxes = nonMaxSuppression(bboxes, nmsThreshold);
for (float[] bbox : bboxes) {
String labelString = "";
if (labels.size() - 1 < label) {
labelString = String.valueOf(label);
} else {
labelString = labels.get(label);
}
detections.add(new Detection(labelString, entry.getKey(), Arrays.copyOfRange(bbox, 0, 4), bbox[4]));
}
}
return detections;
}
public List<DetectResult> predictYolo(String imagePath) throws OrtException {
return predictYolo(Imgcodecs.imread(imagePath));
}
public List<DetectResult> predictYolo(Mat image) throws OrtException {
prepareSession();
long start_time = System.currentTimeMillis();
Map<String, OnnxTensor> inputMap = preprocessImage(image);
// 运行推理
try (OrtSession.Result output = session.run(inputMap)) {
Log.d(TAG, "predictYolo: onnx run cost " + (System.currentTimeMillis() - start_time) + "ms");
List<Detection> detections = postProcessOutput(output);
Log.d("YoloV8Predictor", String.format("onnx predict cost: %d ms", (System.currentTimeMillis() - start_time)));
return detections.stream().map(detection -> new DetectResult(detection, letterbox))
.collect(Collectors.toList());
} finally {
// 释放资源
inputMap.values().forEach(OnnxTensor::close);
}
}
public static void xywh2xyxy(float[] bbox) {
float x = bbox[0];
float y = bbox[1];
float w = bbox[2];
float h = bbox[3];
bbox[0] = x - w * 0.5f;
bbox[1] = y - h * 0.5f;
bbox[2] = x + w * 0.5f;
bbox[3] = y + h * 0.5f;
}
public static float[][] transposeMatrix(float[][] m) {
float[][] temp = new float[m[0].length][m.length];
for (int i = 0; i < m.length; i++) {
for (int j = 0; j < m[0].length; j++) {
temp[j][i] = m[i][j];
}
}
return temp;
}
public static List<float[]> nonMaxSuppression(List<float[]> bboxes, float iouThreshold) {
long start = System.currentTimeMillis();
List<float[]> bestBboxes = new ArrayList<>();
bboxes.sort(Comparator.comparing(a -> a[4]));
while (!bboxes.isEmpty()) {
float[] bestBbox = bboxes.remove(bboxes.size() - 1);
bestBboxes.add(bestBbox);
bboxes.removeIf(bbox -> computeIOU(bbox, bestBbox) >= iouThreshold);
}
Log.d(TAG, "nonMaxSuppression: cost " + (System.currentTimeMillis() - start) + "ms");
return bestBboxes;
}
public static float computeIOU(float[] box1, float[] box2) {
float area1 = (box1[2] - box1[0]) * (box1[3] - box1[1]);
float area2 = (box2[2] - box2[0]) * (box2[3] - box2[1]);
float left = Math.max(box1[0], box2[0]);
float top = Math.max(box1[1], box2[1]);
float right = Math.min(box1[2], box2[2]);
float bottom = Math.min(box1[3], box2[3]);
// 计算交集区域的宽度和高度
float width = Math.max(right - left, 0);
float height = Math.max(bottom - top, 0);
// 计算交集面积和并集面积
float interArea = width * height;
float unionArea = area1 + area2 - interArea;
// 计算交并比
return Math.max(interArea / unionArea, 1e-8f);
}
//返回最大值的索引
// 优化后的 argmax 函数
public static int argmax(float[] a, int start) {
float re = -Float.MAX_VALUE;
int arg = -1;
for (int i = start; i < a.length; i++) {
if (a[i] >= re) {
re = a[i];
arg = i - start;
}
}
return arg;
}
@Override
public void release() {
if (session != null) {
try {
session.close();
session = null;
} catch (OrtException e) {
Log.e(TAG, "close session failed" + e);
}
environment.close();
environment = null;
}
}
}

View File

@ -1,23 +1,15 @@
package com.stardust.autojs.runtime.api;
import android.media.Image;
import android.os.Build;
import android.util.Log;
import com.stardust.autojs.core.image.ImageWrapper;
import com.stardust.autojs.ncnn.NcnnYoloV8Predictor;
import com.stardust.autojs.onnx.YoloV8Predictor;
import com.stardust.autojs.onnx.domain.DetectResult;
import com.stardust.autojs.runtime.ScriptRuntime;
import com.stardust.autojs.yolo.ModelInitParams;
import com.stardust.autojs.yolo.YoloInstance;
import com.stardust.autojs.yolo.ncnn.NcnnInitParams;
import com.stardust.autojs.yolo.ncnn.NcnnYoloInstanceFactory;
import com.stardust.autojs.yolo.onnx.OnnxYoloInstanceFactory;
import org.opencv.core.CvType;
import org.opencv.core.Mat;
import org.opencv.core.Rect;
import java.util.Collections;
import java.util.List;
import ai.onnxruntime.OrtException;
import androidx.annotation.RequiresApi;
/**
@ -28,96 +20,26 @@ import androidx.annotation.RequiresApi;
public class Yolo {
private static final String TAG = "Yolo";
public YoloInstance createNcnn(String paramPath, String binPath, List<String> labels, Integer imageSize, boolean useGpu) {
return new YoloInstance() {
private NcnnYoloV8Predictor ncnnYoloV8 = new NcnnYoloV8Predictor(paramPath, binPath, labels);
{
ncnnYoloV8.setShapeSize(imageSize);
ncnnYoloV8.setUseGpu(useGpu);
Log.d(TAG, "ncnnYoloV8 instance initializer: " + ncnnYoloV8.init());
}
@Override
public YoloPredictor getPredictor() {
return ncnnYoloV8;
}
@Override
public List<DetectResult> predictYolo(Mat image) {
return ncnnYoloV8.predictYolo(image);
}
};
}
private final NcnnYoloInstanceFactory ncnnFactory = new NcnnYoloInstanceFactory();
private final OnnxYoloInstanceFactory onnxFactory = new OnnxYoloInstanceFactory();
public YoloInstance createOnnx(String modelPath, List<String> labels, Integer imageSize) {
return new YoloInstance() {
private YoloV8Predictor onnxYoloV8 = new YoloV8Predictor(modelPath);
{
onnxYoloV8.setLabels(labels);
onnxYoloV8.setShapeSize(imageSize, imageSize);
}
@Override
public YoloPredictor getPredictor() {
return onnxYoloV8;
}
@Override
public List<DetectResult> predictYolo(Mat image) {
try {
return onnxYoloV8.predictYolo(image);
} catch (OrtException e) {
return Collections.emptyList();
}
}
};
ModelInitParams params = new ModelInitParams();
params.setModelPath(modelPath);
params.setLabels(labels);
params.setImageSize(imageSize);
return onnxFactory.createInstance(params);
}
public static abstract class YoloInstance {
public abstract YoloPredictor getPredictor();
public abstract List<DetectResult> predictYolo(Mat image);
public void setConfThreshold(float confThreshold) {
getPredictor().setConfThreshold(confThreshold);
}
public void setNmsThreshold(float nmsThreshold) {
getPredictor().setNmsThreshold(nmsThreshold);
}
public boolean isInit() {
return getPredictor().isInit();
}
public void release() {
getPredictor().release();
}
public List<DetectResult> captureAndPredict(ScriptRuntime runtime, Rect rect) {
Images images = (Images)runtime.getImages();
Image image = images.captureScreenRaw();
if (image != null) {
ImageWrapper imageWrapper = ImageWrapper.ofImageByMat(image, CvType.CV_8UC4);
image.close();
Mat mat = imageWrapper.getMat();
if (rect != null) {
// 裁切图像
Mat croppedImage = new Mat(mat, rect);
mat.release();
mat = croppedImage;
}
List<DetectResult> results = this.predictYolo(mat);
mat.release();
return results;
}
return Collections.emptyList();
}
public YoloInstance createNcnn(String paramPath, String binPath, List<String> labels, Integer imageSize, boolean useGpu) {
NcnnInitParams params = new NcnnInitParams();
params.setParamPath(paramPath);
params.setBinPath(binPath);
params.setLabels(labels);
params.setImageSize(imageSize);
params.setUseGpu(useGpu);
return ncnnFactory.createInstance(params);
}
}

View File

@ -0,0 +1,94 @@
package com.stardust.autojs.yolo;
import android.util.Log;
import com.stardust.autojs.yolo.onnx.domain.DetectResult;
import org.opencv.core.Mat;
import java.util.Collections;
import java.util.List;
/**
* BaseYoloInstance类是一个实现了YoloInstance接口的具体类用于封装YoloPredictor对象
* 并提供YOLO模型推理的核心功能包括预测设置阈值检查初始化状态以及释放资源
*
* @author TonyJiangWJ
* @since 2025/1/5
*/
public class BaseYoloInstance extends YoloInstance {
private final YoloPredictor predictor;
/**
* 构造函数初始化BaseYoloInstance实例
*
* @param predictor YoloPredictor对象用于执行YOLO模型的推理操作
*/
public BaseYoloInstance(YoloPredictor predictor) {
this.predictor = predictor;
}
/**
* 获取当前实例的YoloPredictor对象
*
* @return 返回封装的YoloPredictor对象
*/
@Override
public YoloPredictor getPredictor() {
return predictor;
}
/**
* 对输入的图像进行YOLO模型推理返回检测结果列表
*
* @param image 输入的图像数据类型为Mat通常来自OpenCV
* @return 返回检测结果列表如果推理失败则返回空列表
*/
@Override
public List<DetectResult> predictYolo(Mat image) {
try {
return predictor.predictYolo(image);
} catch (Exception e) {
Log.e("BaseYoloInstance", "predictYolo: failed", e);
return Collections.emptyList();
}
}
/**
* 设置YOLO模型的置信度阈值
*
* @param confThreshold 置信度阈值范围通常为0到1
*/
@Override
public void setConfThreshold(float confThreshold) {
predictor.setConfThreshold(confThreshold);
}
/**
* 设置YOLO模型的非极大值抑制NMS阈值
*
* @param nmsThreshold NMS阈值范围通常为0到1
*/
@Override
public void setNmsThreshold(float nmsThreshold) {
predictor.setNmsThreshold(nmsThreshold);
}
/**
* 检查YOLO模型是否已经初始化
*
* @return 如果模型已初始化返回true否则返回false
*/
@Override
public boolean isInit() {
return predictor.isInit();
}
/**
* 释放YOLO模型占用的资源
*/
@Override
public void release() {
predictor.release();
}
}

View File

@ -0,0 +1,40 @@
package com.stardust.autojs.yolo;
import java.util.List;
/**
* 用于存储模型初始化所需的参数
*
* @author TonyJiangWJ
* @since 2025/1/5
*/
public class ModelInitParams {
private String modelPath;
private List<String> labels;
private Integer imageSize;
public String getModelPath() {
return modelPath;
}
public void setModelPath(String modelPath) {
this.modelPath = modelPath;
}
public List<String> getLabels() {
return labels;
}
public void setLabels(List<String> labels) {
this.labels = labels;
}
public Integer getImageSize() {
return imageSize;
}
public void setImageSize(Integer imageSize) {
this.imageSize = imageSize;
}
}

View File

@ -0,0 +1,102 @@
package com.stardust.autojs.yolo;
import android.media.Image;
import com.stardust.autojs.core.image.ImageWrapper;
import com.stardust.autojs.runtime.ScriptRuntime;
import com.stardust.autojs.runtime.api.Images;
import com.stardust.autojs.yolo.onnx.domain.DetectResult;
import org.opencv.core.CvType;
import org.opencv.core.Mat;
import org.opencv.core.Rect;
import java.util.Collections;
import java.util.List;
/**
* YoloInstance是一个抽象类定义了YOLO实例的基本行为和功能
* 该类提供了YOLO模型推理的核心方法包括预测设置阈值检查初始化状态释放资源以及捕获屏幕并预测的功能
*
* @author TonyJiangWJ
* @since 2025/1/5
*/
public abstract class YoloInstance {
/**
* 获取当前实例的YoloPredictor对象
*
* @return 返回封装的YoloPredictor对象
*/
public abstract YoloPredictor getPredictor();
/**
* 对输入的图像进行YOLO模型推理返回检测结果列表
*
* @param image 输入的图像数据类型为Mat通常来自OpenCV
* @return 返回检测结果列表
*/
public abstract List<DetectResult> predictYolo(Mat image);
/**
* 设置YOLO模型的置信度阈值
*
* @param confThreshold 置信度阈值范围通常为0到1
*/
public void setConfThreshold(float confThreshold) {
getPredictor().setConfThreshold(confThreshold);
}
/**
* 设置YOLO模型的非极大值抑制NMS阈值
*
* @param nmsThreshold NMS阈值范围通常为0到1
*/
public void setNmsThreshold(float nmsThreshold) {
getPredictor().setNmsThreshold(nmsThreshold);
}
/**
* 检查YOLO模型是否已经初始化
*
* @return 如果模型已初始化返回true否则返回false
*/
public boolean isInit() {
return getPredictor().isInit();
}
/**
* 释放YOLO模型占用的资源
*/
public void release() {
getPredictor().release();
}
/**
* 捕获屏幕图像并进行YOLO模型推理
*
* @param runtime 脚本运行时环境用于获取图像捕获功能
* @param rect 指定捕获屏幕的区域如果为null则捕获整个屏幕
* @return 返回检测结果列表如果捕获或推理失败则返回空列表
*/
public List<DetectResult> captureAndPredict(ScriptRuntime runtime, Rect rect) {
Images images = (Images) runtime.getImages();
Image image = images.captureScreenRaw();
if (image != null) {
ImageWrapper imageWrapper = ImageWrapper.ofImageByMat(image, CvType.CV_8UC4);
image.close();
Mat mat = imageWrapper.getMat();
if (rect != null) {
// 裁切图像
Mat croppedImage = new Mat(mat, rect);
mat.release();
mat = croppedImage;
}
List<DetectResult> results = this.predictYolo(mat);
mat.release();
return results;
}
return Collections.emptyList();
}
}

View File

@ -0,0 +1,19 @@
package com.stardust.autojs.yolo;
/**
* yolo实例抽象工厂用于创建不同类型的yolo实例 目前支持ncnn和onnx的yolov8版本实例
*
* @param <P> 模型初始化参数
* @author TonyJiangWJ
* @since 2025/1/5
*/
public interface YoloInstanceFactory<P extends ModelInitParams> {
/**
* 创建yolo实例
*
* @param initParams 初始化参数
* @return 返回yolo实例
*/
YoloInstance createInstance(P initParams);
}

View File

@ -1,8 +1,11 @@
package com.stardust.autojs.runtime.api;
package com.stardust.autojs.yolo;
import android.util.Log;
import com.stardust.autojs.core.opencv.OpenCVHelper;
import com.stardust.autojs.yolo.onnx.domain.DetectResult;
import org.opencv.core.Mat;
import java.util.List;
@ -10,7 +13,7 @@ import java.util.List;
* @author TonyJiangWJ
* @since 2024/6/1
*/
public class YoloPredictor {
public abstract class YoloPredictor {
static {
OpenCVHelper.initIfNeeded(null, () -> {
@ -54,6 +57,8 @@ public class YoloPredictor {
return init;
}
public abstract List<DetectResult> predictYolo(Mat image) throws Exception;
public void release() {
}

View File

@ -0,0 +1,34 @@
package com.stardust.autojs.yolo.ncnn;
import com.stardust.autojs.yolo.ModelInitParams;
public class NcnnInitParams extends ModelInitParams {
private String paramPath;
private String binPath;
private boolean useGpu;
public String getParamPath() {
return paramPath;
}
public void setParamPath(String paramPath) {
this.paramPath = paramPath;
}
public String getBinPath() {
return binPath;
}
public void setBinPath(String binPath) {
this.binPath = binPath;
}
public boolean isUseGpu() {
return useGpu;
}
public void setUseGpu(boolean useGpu) {
this.useGpu = useGpu;
}
}

View File

@ -0,0 +1,27 @@
package com.stardust.autojs.yolo.ncnn;
import android.os.Build;
import android.util.Log;
import com.stardust.autojs.yolo.BaseYoloInstance;
import com.stardust.autojs.yolo.YoloInstance;
import com.stardust.autojs.yolo.YoloInstanceFactory;
import androidx.annotation.RequiresApi;
public class NcnnYoloInstanceFactory implements YoloInstanceFactory<NcnnInitParams> {
@RequiresApi(api = Build.VERSION_CODES.N)
@Override
public YoloInstance createInstance(NcnnInitParams initParams) {
NcnnYoloV8Predictor predictor = new NcnnYoloV8Predictor(initParams.getParamPath(),
initParams.getBinPath(),
initParams.getLabels());
predictor.setShapeSize(initParams.getImageSize());
predictor.setUseGpu(initParams.isUseGpu());
Log.d("NcnnYoloInstanceFactory", "ncnnYoloV8 instance initializer: " + predictor.init());
return new BaseYoloInstance(predictor);
}
}

View File

@ -1,16 +1,15 @@
package com.stardust.autojs.ncnn;
package com.stardust.autojs.yolo.ncnn;
import android.os.Build;
import android.util.Log;
import com.google.android.gms.common.util.CollectionUtils;
import com.stardust.autojs.onnx.domain.DetectResult;
import com.stardust.autojs.runtime.api.YoloPredictor;
import com.tony.yolov8ncnn.PredictResult;
import com.stardust.autojs.yolo.YoloPredictor;
import com.stardust.autojs.yolo.onnx.domain.DetectResult;
import com.tony.yolov8ncnn.NcnnPredictorNative;
import com.tony.yolov8ncnn.PredictResult;
import org.opencv.core.Mat;
import org.opencv.imgproc.Imgproc;
import java.util.Collections;
import java.util.List;
@ -22,6 +21,8 @@ import java.util.stream.Collectors;
import androidx.annotation.RequiresApi;
/**
* Ncnn YoloV8推理器
*
* @author TonyJiangWJ
* @since 2024/6/1
*/

View File

@ -0,0 +1,34 @@
package com.stardust.autojs.yolo.onnx;
import android.os.Build;
import com.stardust.autojs.yolo.BaseYoloInstance;
import com.stardust.autojs.yolo.ModelInitParams;
import com.stardust.autojs.yolo.YoloInstance;
import com.stardust.autojs.yolo.YoloInstanceFactory;
import androidx.annotation.RequiresApi;
/**
* OnnxYoloV8实例创建工厂
*
* @author TonyJiangWJ
* @since 2025/1/5
*/
public class OnnxYoloInstanceFactory implements YoloInstanceFactory<ModelInitParams> {
/**
* 创建YoloInstance实例
*
* @param modelInitParams 初始化参数
* @return
*/
@RequiresApi(api = Build.VERSION_CODES.N)
@Override
public YoloInstance createInstance(ModelInitParams modelInitParams) {
OnnxYoloV8Predictor predictor = new OnnxYoloV8Predictor(modelInitParams.getModelPath());
predictor.setLabels(modelInitParams.getLabels());
predictor.setShapeSize(modelInitParams.getImageSize(), modelInitParams.getImageSize());
return new BaseYoloInstance(predictor);
}
}

View File

@ -0,0 +1,375 @@
package com.stardust.autojs.yolo.onnx;
import android.os.Build;
import android.util.Log;
import com.google.gson.Gson;
import com.stardust.autojs.yolo.YoloPredictor;
import com.stardust.autojs.yolo.onnx.domain.DetectResult;
import com.stardust.autojs.yolo.onnx.domain.Detection;
import com.stardust.autojs.yolo.onnx.util.Letterbox;
import org.opencv.core.CvType;
import org.opencv.core.Mat;
import org.opencv.core.Size;
import org.opencv.imgcodecs.Imgcodecs;
import org.opencv.imgproc.Imgproc;
import java.nio.FloatBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import ai.onnxruntime.providers.NNAPIFlags;
import androidx.annotation.RequiresApi;
/**
* @author TonyJiangWJ
* @since 2023/8/20
* transfer from https://gitee.com/agricultureiot/yolo-onnx-java
*/
@RequiresApi(api = Build.VERSION_CODES.N)
public class OnnxYoloV8Predictor extends YoloPredictor {
private static final String TAG = "YoloV8Predictor";
private static final Pattern IMG_SIZE_PATTERN = Pattern.compile("\\[(\\d+), \\d+]");
private static final Pattern LABEL_PATTERN = Pattern.compile("'([^']*)'");
private final String modelPath;
private boolean tryNpu;
private Size shapeSize = new Size(640, 640);
private Letterbox letterbox;
private List<String> apiFlags = Arrays.asList("CPU_DISABLED");
public OnnxYoloV8Predictor(String modelPath) {
this.modelPath = modelPath;
init = true;
}
public OnnxYoloV8Predictor(String modelPath, float confThreshold, float nmsThreshold) {
this.modelPath = modelPath;
this.confThreshold = confThreshold;
this.nmsThreshold = nmsThreshold;
init = true;
}
public void setShapeSize(double width, double height) {
this.shapeSize = new Size(width, height);
}
public void setTryNpu(boolean tryNpu) {
this.tryNpu = tryNpu;
}
public void setApiFlags(List<String> apiFlags) {
this.apiFlags = apiFlags;
}
private OrtSession session;
private OrtEnvironment environment;
private void prepareSession() throws OrtException {
if (environment != null) {
return;
}
// 加载ONNX模型
environment = OrtEnvironment.getEnvironment();
OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions();
addNNApiProvider(sessionOptions);
session = environment.createSession(modelPath, sessionOptions);
// 输出基本信息
session.getInputInfo().keySet().forEach(x -> {
try {
System.out.println("input name = " + x);
System.out.println(session.getInputInfo().get(x).getInfo().toString());
} catch (OrtException e) {
throw new RuntimeException(e);
}
});
// 如果入参labels无效或未定义使用模型内置labels
if (labels == null || labels.size() == 0) {
labels = initLabels(session);
}
initShapeSize(session);
}
private List<String> initLabels(OrtSession session) throws OrtException {
String meteStr = session.getMetadata().getCustomMetadata().get("names");
if (meteStr == null) {
Log.d(TAG, "initLabels: 读取names失败 无法自动修正labels");
return Collections.emptyList();
}
String[] labels = new String[meteStr.split(",").length];
Matcher matcher = LABEL_PATTERN.matcher(meteStr);
int h = 0;
while (matcher.find()) {
labels[h] = matcher.group(1);
h++;
}
return Arrays.asList(labels);
}
private void initShapeSize(OrtSession session) throws OrtException {
String meteStr = session.getMetadata().getCustomMetadata().get("imgsz");
Log.d(TAG, "initShapeSize: " + meteStr);
if (meteStr == null) {
Log.d(TAG, "initShapeSize: 读取imgsz失败 无法自动修正输入大小");
return;
}
Matcher matcher = IMG_SIZE_PATTERN.matcher(meteStr);
if (matcher.find()) {
String shapeSize = matcher.group(1);
if (shapeSize == null) {
Log.d(TAG, "initShapeSize: 读取imgsz格式异常 无法自动修正输入大小");
return;
}
this.shapeSize = new Size(Double.parseDouble(shapeSize), Double.parseDouble(shapeSize));
Log.d(TAG, "set shape size: " + shapeSize);
} else {
Log.d(TAG, "initShapeSize: 读取imgsz格式异常 无法自动修正输入大小");
}
}
private void addNNApiProvider(OrtSession.SessionOptions sessionOptions) {
if (!tryNpu) {
return;
}
try {
List<NNAPIFlags> flags = new ArrayList<>();
if (apiFlags.contains("USE_FP16")) {
flags.add(NNAPIFlags.USE_FP16);
}
if (apiFlags.contains("USE_NCHW")) {
flags.add(NNAPIFlags.USE_NCHW);
}
if (apiFlags.contains("CPU_ONLY")) {
flags.add(NNAPIFlags.CPU_ONLY);
}
if (apiFlags.contains("CPU_DISABLED")) {
flags.add(NNAPIFlags.CPU_DISABLED);
}
Log.d(TAG, "addNNApiProvider: 当前启用nnapiFlags:" + new Gson().toJson(apiFlags));
sessionOptions.addNnapi(EnumSet.copyOf(flags));
Log.d(TAG, "prepareSession: 启用nnapi成功");
} catch (Exception e) {
Log.e(TAG, "prepareSession: 无法启用nnapi");
}
}
private HashMap<String, OnnxTensor> preprocessImage(Mat img) throws OrtException {
// 读取 image
Mat image = img.clone();
// 将四通道转换为三通道
if (image.channels() == 4) {
Imgproc.cvtColor(image, image, Imgproc.COLOR_RGBA2RGB);
}
Log.d(TAG, "preprocessImage: image's channels: " + image.channels());
// 更改 image 尺寸
letterbox = new Letterbox();
letterbox.setNewShape(this.shapeSize);
image = letterbox.letterbox(image);
int rows = letterbox.getHeight();
int cols = letterbox.getWidth();
int channels = image.channels();
// 转换Mat对象的数据类型为CV_64F即64位浮点型
Mat convertedImage = new Mat();
image.convertTo(convertedImage, CvType.CV_64F);
// 获取整个像素数据
double[] pixelData = new double[rows * cols * channels];
convertedImage.get(0, 0, pixelData);
float[] pixels = new float[channels * rows * cols];
for (int i = 0; i < rows; i++) {
for (int j = 0; j < cols; j++) {
for (int k = 0; k < channels; k++) {
// 这样设置相当于同时做了image.transpose((2, 0, 1))操作
// 重新组织内存访问模式提高缓存效率
pixels[k * rows * cols + i * cols + j] = (float) (pixelData[(i * cols + j) * channels + k] / 255.0);
}
}
}
image.release();
convertedImage.release();
// 创建OnnxTensor对象
long[] shape = {1L, (long) channels, (long) rows, (long) cols};
OnnxTensor tensor = OnnxTensor.createTensor(environment, FloatBuffer.wrap(pixels), shape);
HashMap<String, OnnxTensor> stringOnnxTensorHashMap = new HashMap<>();
stringOnnxTensorHashMap.put(session.getInputInfo().keySet().iterator().next(), tensor);
return stringOnnxTensorHashMap;
}
private List<Detection> postProcessOutput(OrtSession.Result output) throws OrtException {
float[][] outputData = ((float[][][]) output.get(0).getValue())[0];
outputData = transposeMatrix(outputData);
Map<Integer, List<float[]>> class2Bbox = new HashMap<>();
for (float[] bbox : outputData) {
int label = argmax(bbox, 4); // 直接在原数组上进行操作
float conf = bbox[label + 4];
if (conf < confThreshold) {
continue;
}
bbox[4] = conf;
// xywh to (x1, y1, x2, y2)
xywh2xyxy(bbox);
// skip invalid predictions
if (bbox[0] >= bbox[2] || bbox[1] >= bbox[3]) {
continue;
}
class2Bbox.computeIfAbsent(label, k -> new ArrayList<>()).add(bbox);
}
List<Detection> detections = new ArrayList<>();
for (Map.Entry<Integer, List<float[]>> entry : class2Bbox.entrySet()) {
int label = entry.getKey();
List<float[]> bboxes = entry.getValue();
bboxes = nonMaxSuppression(bboxes, nmsThreshold);
for (float[] bbox : bboxes) {
String labelString = "";
if (labels.size() - 1 < label) {
labelString = String.valueOf(label);
} else {
labelString = labels.get(label);
}
detections.add(new Detection(labelString, entry.getKey(), Arrays.copyOfRange(bbox, 0, 4), bbox[4]));
}
}
return detections;
}
public List<DetectResult> predictYolo(String imagePath) throws OrtException {
return predictYolo(Imgcodecs.imread(imagePath));
}
public List<DetectResult> predictYolo(Mat image) throws OrtException {
prepareSession();
long start_time = System.currentTimeMillis();
Map<String, OnnxTensor> inputMap = preprocessImage(image);
// 运行推理
try (OrtSession.Result output = session.run(inputMap)) {
Log.d(TAG, "predictYolo: onnx run cost " + (System.currentTimeMillis() - start_time) + "ms");
List<Detection> detections = postProcessOutput(output);
Log.d("YoloV8Predictor", String.format("onnx predict cost: %d ms", (System.currentTimeMillis() - start_time)));
return detections.stream().map(detection -> new DetectResult(detection, letterbox))
.collect(Collectors.toList());
} finally {
// 释放资源
inputMap.values().forEach(OnnxTensor::close);
}
}
public static void xywh2xyxy(float[] bbox) {
float x = bbox[0];
float y = bbox[1];
float w = bbox[2];
float h = bbox[3];
bbox[0] = x - w * 0.5f;
bbox[1] = y - h * 0.5f;
bbox[2] = x + w * 0.5f;
bbox[3] = y + h * 0.5f;
}
public static float[][] transposeMatrix(float[][] m) {
float[][] temp = new float[m[0].length][m.length];
for (int i = 0; i < m.length; i++) {
for (int j = 0; j < m[0].length; j++) {
temp[j][i] = m[i][j];
}
}
return temp;
}
public static List<float[]> nonMaxSuppression(List<float[]> bboxes, float iouThreshold) {
long start = System.currentTimeMillis();
List<float[]> bestBboxes = new ArrayList<>();
bboxes.sort(Comparator.comparing(a -> a[4]));
while (!bboxes.isEmpty()) {
float[] bestBbox = bboxes.remove(bboxes.size() - 1);
bestBboxes.add(bestBbox);
bboxes.removeIf(bbox -> computeIOU(bbox, bestBbox) >= iouThreshold);
}
Log.d(TAG, "nonMaxSuppression: cost " + (System.currentTimeMillis() - start) + "ms");
return bestBboxes;
}
public static float computeIOU(float[] box1, float[] box2) {
float area1 = (box1[2] - box1[0]) * (box1[3] - box1[1]);
float area2 = (box2[2] - box2[0]) * (box2[3] - box2[1]);
float left = Math.max(box1[0], box2[0]);
float top = Math.max(box1[1], box2[1]);
float right = Math.min(box1[2], box2[2]);
float bottom = Math.min(box1[3], box2[3]);
// 计算交集区域的宽度和高度
float width = Math.max(right - left, 0);
float height = Math.max(bottom - top, 0);
// 计算交集面积和并集面积
float interArea = width * height;
float unionArea = area1 + area2 - interArea;
// 计算交并比
return Math.max(interArea / unionArea, 1e-8f);
}
//返回最大值的索引
// 优化后的 argmax 函数
public static int argmax(float[] a, int start) {
float re = -Float.MAX_VALUE;
int arg = -1;
for (int i = start; i < a.length; i++) {
if (a[i] >= re) {
re = a[i];
arg = i - start;
}
}
return arg;
}
@Override
public void release() {
this.init = false;
if (session != null) {
try {
session.close();
session = null;
} catch (OrtException e) {
Log.e(TAG, "close session failed" + e);
}
environment.close();
environment = null;
}
}
}

View File

@ -1,12 +1,13 @@
package com.stardust.autojs.onnx.domain;
package com.stardust.autojs.yolo.onnx.domain;
import android.graphics.Rect;
import com.stardust.autojs.onnx.util.Letterbox;
import com.stardust.autojs.yolo.onnx.util.Letterbox;
/**
* @author TonyJiangWJ
* @since 2023/8/20
* transfer from https://gitee.com/agricultureiot/yolo-onnx-java
* transfer from <a href="https://gitee.com/agricultureiot/yolo-onnx-java">yolo-onnx-java</a>
*/
public class DetectResult {

View File

@ -1,4 +1,4 @@
package com.stardust.autojs.onnx.domain;
package com.stardust.autojs.yolo.onnx.domain;
/**
* @author TonyJiangWJ
@ -15,14 +15,14 @@ public class Detection {
public float confidence;
public Detection(String label,Integer clsId, float[] bbox, float confidence){
public Detection(String label, Integer clsId, float[] bbox, float confidence) {
this.clsId = clsId;
this.label = label;
this.bbox = bbox;
this.confidence = confidence;
}
public Detection(){
public Detection() {
}
@ -52,12 +52,12 @@ public class Detection {
@Override
public String toString() {
return " label="+label +
" \t clsId="+clsId +
" \t x0="+bbox[0] +
" \t y0="+bbox[1] +
" \t x1="+bbox[2] +
" \t y1="+bbox[3] +
" \t score="+confidence;
return " label=" + label +
" \t clsId=" + clsId +
" \t x0=" + bbox[0] +
" \t y0=" + bbox[1] +
" \t x1=" + bbox[2] +
" \t y1=" + bbox[3] +
" \t score=" + confidence;
}
}

View File

@ -1,4 +1,4 @@
package com.stardust.autojs.onnx.util;
package com.stardust.autojs.yolo.onnx.util;
import org.opencv.core.Core;
import org.opencv.core.Mat;
@ -8,7 +8,7 @@ import org.opencv.imgproc.Imgproc;
/**
* @author TonyJiangWJ
* @since 2023/8/20
* transfer from https://gitee.com/agricultureiot/yolo-onnx-java
* transfer from <a href="https://gitee.com/agricultureiot/yolo-onnx-java">yolo-onnx-java</a>
*/
public class Letterbox {