mirror of
https://github.com/TonyJiangWJ/Auto.js.git
synced 2026-06-12 21:01:32 +08:00
YOLO相关代码使用抽象工厂重构提高后续扩展和可维护性
This commit is contained in:
parent
020258392f
commit
5aa55cda3e
@ -1,5 +1,8 @@
|
||||
NCNN和PaddleOCR同时使用时有兼容性问题导致无法正常运行甚至闪退,请勿同时使用
|
||||
NCNN和PaddleOCR同时使用时,
|
||||
有兼容性问题导致无法正常运行甚至闪退,请勿同时使用
|
||||
|
||||
该问题暂时无法解决,因此如果需要使用PaddleOcr时 请使用onnx,或者使用mlkitocr进行文字识别
|
||||
该问题暂时无法解决,因此如果需要使用PaddleOcr时
|
||||
请使用onnx,或者使用mlkitocr进行文字识别
|
||||
|
||||
建议在依赖性能的情况下使用ncnn,同时使用mlkitocr进行文字识别
|
||||
建议在依赖性能的情况下使用ncnn,
|
||||
同时使用mlkitocr进行文字识别
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
const img = images.read("./test.png")
|
||||
console.show()
|
||||
setTimeout(() -> console.hide(), 15000)
|
||||
setTimeout(() => console.hide(), 15000)
|
||||
let cpuThreadNum = 4
|
||||
// PaddleOCR 移动端提供了两种模型:ocr_v3_for_cpu与ocr_v3_for_cpu(slim),此选项用于选择加载的模型,默认true使用v3的slim版(速度更快),false使用v3的普通版(准确率更高)
|
||||
let useSlim = true
|
||||
@ -14,8 +14,8 @@ let result = $ocr.detect(img, { cpuThreadNum, useSlim })
|
||||
img.recycle()
|
||||
log('slim识别耗时:' + (new Date() - start) + 'ms')
|
||||
|
||||
let model_path = '/sdcard/脚本/best.bin'
|
||||
let param_path = '/sdcard/脚本/best.param'
|
||||
let model_path = '/sdcard/脚本/manor.bin'
|
||||
let param_path = '/sdcard/脚本/manor.param'
|
||||
if (!files.exists(model_path) || !files.exists(param_path)) {
|
||||
toastLog('请确认已下载了模型文件')
|
||||
exit()
|
||||
|
||||
@ -1,8 +1,7 @@
|
||||
console.show()
|
||||
setTimeout(() -> console.hide(), 15000)
|
||||
setTimeout(() -> console.hide(), 15000)
|
||||
let model_path = '/sdcard/脚本/best.bin'
|
||||
let param_path = '/sdcard/脚本/best.param'
|
||||
setTimeout(() => console.hide(), 15000)
|
||||
let model_path = '/sdcard/脚本/manor.bin'
|
||||
let param_path = '/sdcard/脚本/manor.param'
|
||||
if (!files.exists(model_path) || !files.exists(param_path)) {
|
||||
toastLog('请确认已下载了模型文件')
|
||||
exit()
|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 134 KiB |
@ -1,39 +0,0 @@
|
||||
// Demo script: initialise the ncnn YOLO backend with GPU enabled and run it once on ./bus.jpg.
let model_path = '/sdcard/脚本/yolov8n.bin'
let param_path = '/sdcard/脚本/yolov8n.param'
// Both model files must exist before initialisation is attempted.
if (!files.exists(model_path) || !files.exists(param_path)) {
  toastLog('请确认已下载了模型文件')
  exit()
}
console.show()
// Hide the console again after 15 seconds. The original used `->`, which is not a valid arrow token.
setTimeout(() => console.hide(), 15000)
let yoloInit = $yolo.init({
  type: 'ncnn',
  useGpu: true,
  paramPath: files.path(param_path),
  binPath: files.path(model_path),
  imageSize: 480,
  // Label names handed to the predictor.
  labels: [
    "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light",
    "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow",
    "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee",
    "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard",
    "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple",
    "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch",
    "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone",
    "microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear",
    "hair drier", "toothbrush"
  ]
})

if (!yoloInit) {
  toast('初始化失败')
  exit()
}

const img = images.read("./bus.jpg")
// Stop early when the sample image could not be loaded instead of passing a missing image to forward().
if (!img) {
  toastLog('读取图片失败')
  exit()
}
try {
  let start = new Date()
  const result = $yolo.forward(img)
  toastLog('ncnn cost: ' + (new Date() - start) + 'ms')
  log('predict result:' + JSON.stringify(result, null, 4))
} finally {
  // Release the image even when inference throws.
  img.recycle()
}
|
||||
@ -1,38 +0,0 @@
|
||||
// Demo script: initialise the ncnn YOLO backend and run it once on ./bus.jpg.
let model_path = '/sdcard/脚本/yolov8n.bin'
let param_path = '/sdcard/脚本/yolov8n.param'
// Both model files must exist before initialisation is attempted.
if (!files.exists(model_path) || !files.exists(param_path)) {
  toastLog('请确认已下载了模型文件')
  exit()
}
console.show()
// Hide the console again after 15 seconds. The original used `->`, which is not a valid arrow token.
setTimeout(() => console.hide(), 15000)
let yoloInit = $yolo.init({
  type: 'ncnn',
  paramPath: files.path(param_path),
  binPath: files.path(model_path),
  imageSize: 480,
  // Label names handed to the predictor.
  labels: [
    "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light",
    "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow",
    "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee",
    "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard",
    "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple",
    "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch",
    "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone",
    "microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear",
    "hair drier", "toothbrush"
  ]
})

if (!yoloInit) {
  toast('初始化失败')
  exit()
}

const img = images.read("./bus.jpg")
// Stop early when the sample image could not be loaded instead of passing a missing image to forward().
if (!img) {
  toastLog('读取图片失败')
  exit()
}
try {
  let start = new Date()
  const result = $yolo.forward(img)
  toastLog('ncnn cost: ' + (new Date() - start) + 'ms')
  log('predict result:' + JSON.stringify(result, null, 4))
} finally {
  // Release the image even when inference throws.
  img.recycle()
}
|
||||
@ -1,11 +1,11 @@
|
||||
let model_path = '/sdcard/脚本/best.bin'
|
||||
let param_path = '/sdcard/脚本/best.param'
|
||||
let model_path = '/sdcard/脚本/manor.bin'
|
||||
let param_path = '/sdcard/脚本/manor.param'
|
||||
if (!files.exists(model_path) || !files.exists(param_path)) {
|
||||
toastLog('请确认已下载了模型文件')
|
||||
exit()
|
||||
}
|
||||
console.show()
|
||||
setTimeout(() -> console.hide(), 15000)
|
||||
setTimeout(() => console.hide(), 15000)
|
||||
let yoloInit = $yolo.init({
|
||||
type: 'ncnn',
|
||||
paramPath: files.path(param_path),
|
||||
|
||||
@ -4,7 +4,7 @@ if (!files.exists(model_path)) {
|
||||
exit()
|
||||
}
|
||||
console.show()
|
||||
setTimeout(() -> console.hide(), 15000)
|
||||
setTimeout(() => console.hide(), 15000)
|
||||
let yoloInit = $yolo.init({
|
||||
type: 'onnx',
|
||||
modelPath: files.path(model_path),
|
||||
|
||||
@ -1,307 +0,0 @@
|
||||
// 杀死当前同名脚本 see AutoScriptBase/lib/killMyDuplicator
|
||||
(() => { let g = engines.myEngine(); var e = engines.all(), n = e.length; let r = g.getSource() + ""; 1 < n && e.forEach(e => { var n = e.getSource() + ""; g.id !== e.id && n == r && e.forceStop() }) })();
|
||||
|
||||
if (!requestScreenCapture()) {
|
||||
toastLog('请求截图权限失败')
|
||||
exit()
|
||||
}
|
||||
|
||||
|
||||
let onnxInstance = null
|
||||
let ncnnInstance = null
|
||||
let currentType = 'ncnn'
|
||||
initYoloInstances()
|
||||
let yoloInstance = {
|
||||
ncnn: ncnnInstance,
|
||||
onnx: onnxInstance,
|
||||
}
|
||||
|
||||
// 识别结果和截图信息
|
||||
let result = []
|
||||
let img = null
|
||||
let running = true
|
||||
let capturing = false
|
||||
|
||||
/**
|
||||
* 截图并识别OCR文本信息
|
||||
*/
|
||||
/**
 * Takes a screenshot and runs the currently selected YOLO instance on it.
 *
 * Reads the script-level `yoloInstance` / `currentType`, stores the detections in
 * the script-level `result`, and keeps `capturing` set to true while it is working.
 * Nothing is returned; failures are reported through `toastLog`.
 */
function captureAndDetect () {
  const predictor = yoloInstance[currentType]
  // The selected backend may have failed to initialise, in which case its slot is empty.
  if (!predictor) {
    toastLog('当前推理模型未能初始化,请选择另一个')
    return
  }
  capturing = true
  try {
    img = captureScreen()
    if (!img) {
      toastLog('截图失败')
      // Without a screenshot there is nothing to run inference on.
      return
    }
    let start = new Date()
    result = predictor.forward(img)
    console.verbose('识别结果:' + JSON.stringify(result))
    toastLog('耗时' + (new Date() - start) + 'ms')
  } finally {
    if (img) {
      img.recycle()
      // Drop the reference so the exit handler does not recycle the same image a second time.
      img = null
    }
    // Always clear the flag; otherwise one failure would block the button and the draw loop for good.
    capturing = false
  }
}
|
||||
|
||||
// captureAndDetect()
|
||||
|
||||
// 获取状态栏高度
|
||||
let offset = -getStatusBarHeightCompat()
|
||||
|
||||
// 绘制识别结果
|
||||
let window = floaty.rawWindow(
|
||||
<canvas id="canvas" layout_weight="1" />
|
||||
);
|
||||
|
||||
// 设置悬浮窗位置
|
||||
ui.post(() => {
|
||||
window.setPosition(0, offset)
|
||||
window.setSize(device.width, device.height)
|
||||
window.setTouchable(false)
|
||||
})
|
||||
|
||||
// 操作按钮
|
||||
let clickButtonWindow = floaty.rawWindow(
|
||||
<vertical>
|
||||
<button id="changeYoloType" text="当前ncnn" />
|
||||
<button id="captureAndDetect" text="截图识别" />
|
||||
<button id="closeBtn" text="退出" />
|
||||
</vertical>
|
||||
);
|
||||
ui.run(function () {
|
||||
clickButtonWindow.setPosition(device.width / 2 - ~~(clickButtonWindow.getWidth() / 2), device.height * 0.65)
|
||||
})
|
||||
|
||||
// 切换类型
|
||||
clickButtonWindow.changeYoloType.click(function () {
|
||||
threads.start(function () {
|
||||
changeYoloType()
|
||||
ui.run(function () {
|
||||
if (currentType === 'onnx') {
|
||||
clickButtonWindow.changeYoloType.setText('当前onnx')
|
||||
} else {
|
||||
clickButtonWindow.changeYoloType.setText('当前ncnn')
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
// 点击识别
|
||||
clickButtonWindow.captureAndDetect.click(function () {
|
||||
if (capturing) {
|
||||
return
|
||||
}
|
||||
result = []
|
||||
let oldPosition = {
|
||||
x: clickButtonWindow.getX(),
|
||||
y: clickButtonWindow.getY(),
|
||||
}
|
||||
ui.run(function () {
|
||||
clickButtonWindow.setPosition(device.width, device.height)
|
||||
})
|
||||
setTimeout(() => {
|
||||
captureAndDetect()
|
||||
ui.run(function () {
|
||||
clickButtonWindow.setPosition(oldPosition.x, oldPosition.y)
|
||||
})
|
||||
}, 500)
|
||||
})
|
||||
|
||||
// 点击关闭
|
||||
clickButtonWindow.closeBtn.setOnTouchListener(new TouchController(clickButtonWindow, () => {
|
||||
exit()
|
||||
}).createListener())
|
||||
|
||||
let Typeface = android.graphics.Typeface
|
||||
let paint = new Paint()
|
||||
paint.setStrokeWidth(1)
|
||||
paint.setTypeface(Typeface.DEFAULT_BOLD)
|
||||
paint.setTextAlign(Paint.Align.LEFT)
|
||||
paint.setAntiAlias(true)
|
||||
paint.setStrokeJoin(Paint.Join.ROUND)
|
||||
paint.setDither(true)
|
||||
window.canvas.on('draw', function (canvas) {
|
||||
if (!running || capturing) {
|
||||
return
|
||||
}
|
||||
// 清空内容
|
||||
canvas.drawColor(0xFFFFFF, android.graphics.PorterDuff.Mode.CLEAR)
|
||||
if (result && result.length > 0) {
|
||||
for (let i = 0; i < result.length; i++) {
|
||||
let detectResult = result[i]
|
||||
drawRectAndText(detectResult.label, detectResult.bounds, '#00ff00', canvas, paint)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
setInterval(() => { }, 10000)
|
||||
events.on('exit', () => {
|
||||
// 标记停止 避免canvas导致闪退
|
||||
running = false
|
||||
// 撤销监听
|
||||
window.canvas.removeAllListeners()
|
||||
// 回收图片
|
||||
img && img.recycle()
|
||||
})
|
||||
|
||||
/**
|
||||
* 绘制文本和方框
|
||||
*
|
||||
* @param {*} desc
|
||||
* @param {*} rect
|
||||
* @param {*} colorStr
|
||||
* @param {*} canvas
|
||||
* @param {*} paint
|
||||
*/
|
||||
/**
 * Draws a one-pixel outline for `rect` in the inverse of `colorStr`, then draws
 * `desc` at the rectangle's top-left corner in `colorStr` itself. Afterwards the
 * paint is set back to text size 10, stroke width 1 and opaque black.
 *
 * @param {*} desc text to draw
 * @param {*} rect object passed to canvas.drawRect; its left/top position the text
 * @param {*} colorStr color string understood by colors.parseColor
 * @param {*} canvas drawing target
 * @param {*} paint paint object that is reconfigured in place
 */
function drawRectAndText (desc, rect, colorStr, canvas, paint) {
  const argb = colors.parseColor(colorStr)
  const red = argb >> 16 & 0xff
  const green = argb >> 8 & 0xff
  const blue = argb & 0xff

  // Outline, stroked in the inverse color.
  paint.setStrokeWidth(1)
  paint.setStyle(Paint.Style.STROKE)
  paint.setARGB(255, 255 - red, 255 - green, 255 - blue)
  canvas.drawRect(rect, paint)

  // Label, filled in the requested color.
  paint.setARGB(255, red, green, blue)
  paint.setStrokeWidth(1)
  paint.setTextSize(20)
  paint.setStyle(Paint.Style.FILL)
  canvas.drawText(desc, rect.left, rect.top, paint)

  // Put the shared paint back to its baseline values.
  paint.setTextSize(10)
  paint.setStrokeWidth(1)
  paint.setARGB(255, 0, 0, 0)
}
|
||||
|
||||
/**
|
||||
* 获取状态栏高度
|
||||
*
|
||||
* @returns
|
||||
*/
|
||||
/**
 * Looks up the status bar height from the "status_bar_height" dimen resource of
 * the "android" package. When that resource is missing or yields a non-positive
 * value, the R.dimen.dimen_25dp dimension is used instead.
 *
 * @returns the pixel offset reported by the resources object
 */
function getStatusBarHeightCompat () {
  const resources = context.getResources()
  const statusBarResId = resources.getIdentifier("status_bar_height", "dimen", "android")
  let height = statusBarResId > 0 ? resources.getDimensionPixelOffset(statusBarResId) : 0
  if (height <= 0) {
    // Fall back to the fixed dimension.
    height = resources.getDimensionPixelOffset(R.dimen.dimen_25dp)
  }
  return height
}
|
||||
|
||||
|
||||
/**
 * Initialises the ncnn predictor and then the onnx predictor through `$yolo`,
 * storing whatever `$yolo.getInstance()` returns after each successful init in
 * the script-level `ncnnInstance` / `onnxInstance` variables.
 *
 * Calls `exit()` when a required model file is missing; a failed init is only
 * reported through `toastLog`.
 */
function initYoloInstances () {
  const onnxModelFile = '/sdcard/脚本/manor_lite.onnx'
  const ncnnBinFile = '/sdcard/脚本/best.bin'
  const ncnnParamFile = '/sdcard/脚本/best.param'
  // Labels used by both backends; the onnx list appends three more entries.
  const sharedLabels = [
    'booth_btn', 'collect_coin', 'collect_egg', 'collect_food', 'cook', 'countdown', 'donate',
    'eating_chicken', 'employ', 'empty_booth', 'feed_btn', 'friend_btn', 'has_food', 'has_shit',
    'hungry_chicken', 'item', 'kick-out', 'no_food', 'not_ready', 'operation_booth', 'plz-go',
    'punish_booth', 'punish_btn', 'signboard', 'sleep', 'speedup', 'sports', 'stopped_booth',
    'thief_chicken', 'close_btn', 'collect_muck', 'confirm_btn', 'working_chicken',
  ]

  if (!files.exists(onnxModelFile)) {
    toastLog('请确认已下载了onnx模型文件')
    exit()
  }
  const ncnnFilesPresent = files.exists(ncnnBinFile) && files.exists(ncnnParamFile)
  if (!ncnnFilesPresent) {
    toastLog('请确认已下载了模型文件')
    exit()
  }

  const ncnnOptions = {
    type: 'ncnn',
    paramPath: files.path(ncnnParamFile),
    binPath: files.path(ncnnBinFile),
    imageSize: 480,
    labels: sharedLabels.slice(),
  }
  if ($yolo.init(ncnnOptions)) {
    ncnnInstance = $yolo.getInstance()
  } else {
    toastLog('ncnn初始化失败')
  }

  const onnxOptions = {
    type: 'onnx',
    modelPath: files.path(onnxModelFile),
    imageSize: 480,
    labels: sharedLabels.concat(['bring_back', 'leave_msg', 'speedup_eating']),
  }
  if ($yolo.init(onnxOptions)) {
    onnxInstance = $yolo.getInstance()
  } else {
    toastLog('onnx初始化失败')
  }
}
|
||||
|
||||
|
||||
/**
 * Asks the user which inference backend to use and, when that backend has an
 * instance in `yoloInstance`, stores the choice in the script-level `currentType`.
 *
 * Dismissing the dialog without a selection leaves `currentType` untouched and
 * shows no message.
 */
function changeYoloType () {
  let options = ["ncnn", "onnx"]
  let idx = dialogs.singleChoice("请选择YOLO推理类型", options, options.indexOf(currentType))
  let targetType = options[idx]
  // A cancelled dialog yields an index outside the list (e.g. -1), so there is nothing to switch to.
  if (targetType === undefined) {
    return
  }
  toast("选择了: " + targetType)
  if (!yoloInstance[targetType]) {
    toastLog('目标类型未能初始化:' + targetType)
    return
  }
  currentType = targetType
}
|
||||
|
||||
/**
 * Constructor for a touch handler that lets a floating window be dragged and
 * also reports taps.
 *
 * @param {*} buttonWindow window whose getX/getY/setPosition are used while dragging
 * @param {*} handleClick optional callback for a release without drag within one second of the press
 * @param {*} handleDown optional callback invoked on every press
 * @param {*} handleUp optional callback invoked on every release
 */
function TouchController (buttonWindow, handleClick, handleDown, handleUp) {
  this.eventStartX = null
  this.eventStartY = null
  this.windowStartX = buttonWindow.getX()
  this.windowStartY = buttonWindow.getY()
  this.eventKeep = false
  this.eventMoving = false
  this.touchDownTime = new Date().getTime()

  /**
   * Builds the android.view.View.OnTouchListener backed by this controller.
   * The listener always returns true and logs any exception it catches.
   */
  this.createListener = function () {
    const controller = this
    const onTouch = (view, event) => {
      try {
        const action = event.getAction()
        if (action === event.ACTION_DOWN) {
          handleDown && handleDown()
          // Remember where both the finger and the window started.
          controller.eventStartX = event.getRawX()
          controller.eventStartY = event.getRawY()
          controller.windowStartX = buttonWindow.getX()
          controller.windowStartY = buttonWindow.getY()
          controller.eventKeep = true
          controller.touchDownTime = new Date().getTime()
        } else if (action === event.ACTION_MOVE) {
          const offsetX = event.getRawX() - controller.eventStartX
          const offsetY = event.getRawY() - controller.eventStartY
          // A drag starts once the finger has travelled at least 10 units.
          if (!controller.eventMoving && controller.eventKeep && getDistance(offsetX, offsetY) >= 10) {
            controller.eventMoving = true
          }
          if (controller.eventMoving && controller.eventKeep) {
            ui.post(() => {
              buttonWindow.setPosition(controller.windowStartX + offsetX, controller.windowStartY + offsetY)
            })
          }
        } else if (action === event.ACTION_UP) {
          handleUp && handleUp()
          const pressedRecently = !controller.eventMoving && controller.eventKeep &&
            controller.touchDownTime > new Date().getTime() - 1000
          if (pressedRecently) {
            handleClick && handleClick()
          }
          controller.eventKeep = false
          controller.touchDownTime = 0
          controller.eventMoving = false
        }
      } catch (e) {
        console.error('异常' + e)
      }
      return true
    }
    return new android.view.View.OnTouchListener(onTouch)
  }
}
|
||||
|
||||
/**
 * Length of the vector (dx, dy).
 *
 * @param {number} dx horizontal offset
 * @param {number} dy vertical offset
 * @returns {number} square root of dx squared plus dy squared
 */
function getDistance (dx, dy) {
  const squaredSum = Math.pow(dx, 2) + Math.pow(dy, 2)
  return Math.sqrt(squaredSum)
}
|
||||
@ -7,34 +7,36 @@ if (!requestScreenCapture()) {
|
||||
}
|
||||
|
||||
|
||||
let onnxInstance = null
|
||||
let ncnnInstance = null
|
||||
let currentType = 'ncnn'
|
||||
let initSuccess = false
|
||||
initYoloInstances()
|
||||
|
||||
let yoloInstance = {
|
||||
ncnn: ncnnInstance,
|
||||
onnx: onnxInstance,
|
||||
}
|
||||
|
||||
// 识别结果和截图信息
|
||||
let result = []
|
||||
let img = null
|
||||
let running = true
|
||||
let capturing = false
|
||||
let cost = 0
|
||||
|
||||
/**
|
||||
* 截图并识别OCR文本信息
|
||||
*/
|
||||
function captureAndDetect () {
|
||||
if (!initSuccess) {
|
||||
toastLog('当前推理模型未能初始化,请选择另一个')
|
||||
return
|
||||
}
|
||||
capturing = true
|
||||
img = captureScreen()
|
||||
if (!img) {
|
||||
toastLog('截图失败')
|
||||
}
|
||||
let start = new Date()
|
||||
result = $yolo.forward(img)
|
||||
result = yoloInstance[currentType].forward(img)
|
||||
console.verbose('识别结果:' + JSON.stringify(result))
|
||||
toastLog('耗时' + (new Date() - start) + 'ms')
|
||||
cost = (new Date() - start)
|
||||
toastLog('耗时' + cost + 'ms')
|
||||
img && img.recycle()
|
||||
capturing = false
|
||||
}
|
||||
@ -127,6 +129,10 @@ window.canvas.on('draw', function (canvas) {
|
||||
drawRectAndText(detectResult.label, detectResult.bounds, '#00ff00', canvas, paint)
|
||||
}
|
||||
}
|
||||
drawText('请打开支付宝蚂蚁庄园界面进行识别', 100, device.height - 300, '#00ff00', canvas, paint)
|
||||
if (cost > 0) {
|
||||
drawText('识别耗时:' + cost + 'ms', 100, device.height - 250, '#00ff00', canvas, paint)
|
||||
}
|
||||
})
|
||||
|
||||
setInterval(() => { }, 10000)
|
||||
@ -156,16 +162,42 @@ function drawRectAndText (desc, rect, colorStr, canvas, paint) {
|
||||
// 反色
|
||||
paint.setARGB(255, 255 - (color >> 16 & 0xff), 255 - (color >> 8 & 0xff), 255 - (color & 0xff))
|
||||
canvas.drawRect(rect, paint)
|
||||
paint.setARGB(255, color >> 16 & 0xff, color >> 8 & 0xff, color & 0xff)
|
||||
paint.setStrokeWidth(1)
|
||||
paint.setTextSize(20)
|
||||
paint.setStyle(Paint.Style.FILL)
|
||||
canvas.drawText(desc, rect.left + 1, rect.top + 2, paint)
|
||||
paint.setARGB(255, color >> 16 & 0xff, color >> 8 & 0xff, color & 0xff)
|
||||
canvas.drawText(desc, rect.left, rect.top, paint)
|
||||
paint.setTextSize(10)
|
||||
paint.setStrokeWidth(1)
|
||||
paint.setARGB(255, 0, 0, 0)
|
||||
}
|
||||
|
||||
/**
|
||||
* 绘制文本
|
||||
*
|
||||
* @param {*} desc
|
||||
* @param {*} left
|
||||
* @param {*} top
|
||||
* @param {*} colorStr
|
||||
* @param {*} canvas
|
||||
* @param {*} paint
|
||||
*/
|
||||
function drawText (desc, left, top, colorStr, canvas, paint) {
|
||||
let color = colors.parseColor(colorStr)
|
||||
|
||||
paint.setStrokeWidth(1)
|
||||
paint.setStyle(Paint.Style.STROKE)
|
||||
paint.setStrokeWidth(1)
|
||||
paint.setTextSize(30)
|
||||
paint.setStyle(Paint.Style.FILL)
|
||||
// 反色 阴影
|
||||
paint.setARGB(255, 255 - (color >> 16 & 0xff), 255 - (color >> 8 & 0xff), 255 - (color & 0xff))
|
||||
canvas.drawText(desc, left + 1, top + 2, paint)
|
||||
paint.setARGB(255, color >> 16 & 0xff, color >> 8 & 0xff, color & 0xff)
|
||||
canvas.drawText(desc, left, top, paint)
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取状态栏高度
|
||||
*
|
||||
@ -185,61 +217,55 @@ function getStatusBarHeightCompat () {
|
||||
|
||||
|
||||
function initYoloInstances () {
|
||||
if (initSuccess) {
|
||||
$yolo.release()
|
||||
|
||||
|
||||
let onnx_model_path = '/sdcard/脚本/manor_lite.onnx'
|
||||
if (!files.exists(onnx_model_path)) {
|
||||
toastLog('请确认已下载了onnx模型文件')
|
||||
exit()
|
||||
}
|
||||
initSuccess = false
|
||||
if (currentType == 'ncnn') {
|
||||
let model_path = '/sdcard/脚本/best.bin'
|
||||
let param_path = '/sdcard/脚本/best.param'
|
||||
if (!files.exists(model_path) || !files.exists(param_path)) {
|
||||
toastLog('请确认已下载了模型文件')
|
||||
return
|
||||
}
|
||||
let ncnnInit = $yolo.init({
|
||||
type: 'ncnn',
|
||||
paramPath: files.path(param_path),
|
||||
binPath: files.path(model_path),
|
||||
imageSize: 480,
|
||||
labels: [
|
||||
'booth_btn', 'collect_coin', 'collect_egg', 'collect_food', 'cook', 'countdown', 'donate',
|
||||
'eating_chicken', 'employ', 'empty_booth', 'feed_btn', 'friend_btn', 'has_food', 'has_shit',
|
||||
'hungry_chicken', 'item', 'kick-out', 'no_food', 'not_ready', 'operation_booth', 'plz-go',
|
||||
'punish_booth', 'punish_btn', 'signboard', 'sleep', 'speedup', 'sports', 'stopped_booth',
|
||||
'thief_chicken', 'close_btn', 'collect_muck', 'confirm_btn', 'working_chicken',
|
||||
]
|
||||
})
|
||||
if (ncnnInit) {
|
||||
initSuccess = true
|
||||
} else {
|
||||
toastLog('ncnn初始化失败')
|
||||
}
|
||||
let model_path = '/sdcard/脚本/manor.bin'
|
||||
let param_path = '/sdcard/脚本/manor.param'
|
||||
if (!files.exists(model_path) || !files.exists(param_path)) {
|
||||
toastLog('请确认已下载了ncnn模型文件')
|
||||
exit()
|
||||
}
|
||||
let ncnnInit = $yolo.init({
|
||||
type: 'ncnn',
|
||||
paramPath: files.path(param_path),
|
||||
binPath: files.path(model_path),
|
||||
imageSize: 480,
|
||||
// ncnn 版本必须填写labels
|
||||
labels: [
|
||||
'booth_btn', 'collect_coin', 'collect_egg', 'collect_food', 'cook', 'countdown', 'donate',
|
||||
'eating_chicken', 'employ', 'empty_booth', 'feed_btn', 'friend_btn', 'has_food', 'has_shit',
|
||||
'hungry_chicken', 'item', 'kick-out', 'no_food', 'not_ready', 'operation_booth', 'plz-go',
|
||||
'punish_booth', 'punish_btn', 'signboard', 'sleep', 'speedup', 'sports', 'stopped_booth',
|
||||
'thief_chicken', 'close_btn', 'collect_muck', 'confirm_btn', 'working_chicken',
|
||||
]
|
||||
})
|
||||
if (ncnnInit) {
|
||||
ncnnInstance = $yolo.getInstance()
|
||||
} else {
|
||||
|
||||
let onnx_model_path = '/sdcard/脚本/manor_lite.onnx'
|
||||
if (!files.exists(onnx_model_path)) {
|
||||
toastLog('请确认已下载了onnx模型文件')
|
||||
return
|
||||
}
|
||||
|
||||
let onnxInit = $yolo.init({
|
||||
type: 'onnx',
|
||||
modelPath: files.path(onnx_model_path),
|
||||
imageSize: 480,
|
||||
labels: [
|
||||
'booth_btn', 'collect_coin', 'collect_egg', 'collect_food', 'cook', 'countdown', 'donate',
|
||||
'eating_chicken', 'employ', 'empty_booth', 'feed_btn', 'friend_btn', 'has_food', 'has_shit',
|
||||
'hungry_chicken', 'item', 'kick-out', 'no_food', 'not_ready', 'operation_booth', 'plz-go',
|
||||
'punish_booth', 'punish_btn', 'signboard', 'sleep', 'speedup', 'sports', 'stopped_booth',
|
||||
'thief_chicken', 'close_btn', 'collect_muck', 'confirm_btn', 'working_chicken', 'bring_back',
|
||||
'leave_msg', 'speedup_eating',
|
||||
]
|
||||
})
|
||||
if (onnxInit) {
|
||||
initSuccess = true
|
||||
} else {
|
||||
toastLog('onnx初始化失败')
|
||||
}
|
||||
toastLog('ncnn初始化失败')
|
||||
}
|
||||
let onnxInit = $yolo.init({
|
||||
type: 'onnx',
|
||||
modelPath: files.path(onnx_model_path),
|
||||
imageSize: 480,
|
||||
// onnx版本可以不填写labels,可以通过onnx模型自动提取,当然也可以自己提供,比如映射成中文等
|
||||
labels: [
|
||||
'摆摊按钮', '收集金币', '收蛋', '领饲料', '去做饭', '倒计时', '捐蛋',
|
||||
'eating_chicken', 'employ', '空摊位', 'feed_btn', 'friend_btn', 'has_food', 'has_shit',
|
||||
'hungry_chicken', '道具', 'kick-out', 'no_food', 'not_ready', 'operation_booth', 'plz-go',
|
||||
'punish_booth', 'punish_btn', 'signboard', 'sleep', 'speedup', 'sports', 'stopped_booth',
|
||||
'thief_chicken', 'close_btn', 'collect_muck', 'confirm_btn', 'working_chicken',
|
||||
]
|
||||
})
|
||||
if (onnxInit) {
|
||||
onnxInstance = $yolo.getInstance()
|
||||
} else {
|
||||
toastLog('onnx初始化失败')
|
||||
}
|
||||
}
|
||||
|
||||
@ -249,8 +275,11 @@ function changeYoloType () {
|
||||
let idx = dialogs.singleChoice("请选择YOLO推理类型", options, options.indexOf(currentType))
|
||||
let targetType = options[idx]
|
||||
toast("选择了: " + targetType)
|
||||
if (!yoloInstance[targetType]) {
|
||||
toastLog('目标类型未能初始化:' + targetType)
|
||||
return
|
||||
}
|
||||
currentType = targetType
|
||||
initYoloInstances()
|
||||
}
|
||||
|
||||
function TouchController (buttonWindow, handleClick, handleDown, handleUp) {
|
||||
|
||||
@ -1,9 +1,7 @@
|
||||
存放路径 下载地址
|
||||
存放路径 下载地址:https://pan.quark.cn/s/7242eae30941
|
||||
ncnn:
|
||||
/sdcard/脚本/yolov8n.param
|
||||
/sdcard/脚本/yolov8n.bin
|
||||
/sdcard/脚本/best.param
|
||||
/sdcard/脚本/best.bin
|
||||
/sdcard/脚本/manor.param
|
||||
/sdcard/脚本/manor.bin
|
||||
onnx:
|
||||
/sdcard/脚本/manor_lite.onnx
|
||||
|
||||
|
||||
2
autojs/proguard-rules.pro
vendored
2
autojs/proguard-rules.pro
vendored
@ -25,3 +25,5 @@
|
||||
#-renamesourcefileattribute SourceFile
|
||||
#-keep public class com.stardust.autojs.onnx.YoloV8Predictor
|
||||
#-keepnames class com.stardust.autojs.onnx.YoloV8Predictor
|
||||
#-keep public class com.stardust.autojs.yolo.onnx.OnnxYoloV8Predictor
|
||||
#-keepnames class com.stardust.autojs.yolo.onnx.OnnxYoloV8Predictor
|
||||
|
||||
@ -103,7 +103,6 @@ public class Mat extends org.opencv.core.Mat implements ResourceMonitor.Resource
|
||||
}
|
||||
mReleased = true;
|
||||
}
|
||||
super.finalize();
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@ -1,374 +1,25 @@
|
||||
package com.stardust.autojs.onnx;
|
||||
|
||||
import android.os.Build;
|
||||
import android.util.Log;
|
||||
|
||||
import com.google.gson.Gson;
|
||||
import com.stardust.autojs.onnx.domain.DetectResult;
|
||||
import com.stardust.autojs.onnx.domain.Detection;
|
||||
import com.stardust.autojs.onnx.util.Letterbox;
|
||||
import com.stardust.autojs.runtime.api.YoloPredictor;
|
||||
import com.stardust.autojs.yolo.onnx.OnnxYoloV8Predictor;
|
||||
|
||||
import org.opencv.core.CvType;
|
||||
import org.opencv.core.Mat;
|
||||
import org.opencv.core.Size;
|
||||
import org.opencv.imgcodecs.Imgcodecs;
|
||||
import org.opencv.imgproc.Imgproc;
|
||||
|
||||
import java.nio.FloatBuffer;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.Comparator;
|
||||
import java.util.EnumSet;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.regex.Matcher;
|
||||
import java.util.regex.Pattern;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import ai.onnxruntime.OnnxTensor;
|
||||
import ai.onnxruntime.OrtEnvironment;
|
||||
import ai.onnxruntime.OrtException;
|
||||
import ai.onnxruntime.OrtSession;
|
||||
import ai.onnxruntime.providers.NNAPIFlags;
|
||||
import androidx.annotation.RequiresApi;
|
||||
|
||||
/**
|
||||
* @author TonyJiangWJ
|
||||
* @since 2023/8/20
|
||||
* transfer from https://gitee.com/agricultureiot/yolo-onnx-java
|
||||
* 适配旧版本脚本
|
||||
*/
|
||||
@RequiresApi(api = Build.VERSION_CODES.N)
|
||||
public class YoloV8Predictor extends YoloPredictor {
|
||||
private static final String TAG = "YoloV8Predictor";
|
||||
private static final Pattern IMG_SIZE_PATTERN = Pattern.compile("\\[(\\d+), \\d+]");
|
||||
private static final Pattern LABEL_PATTERN = Pattern.compile("'([^']*)'");
|
||||
|
||||
private final String modelPath;
|
||||
|
||||
private boolean tryNpu;
|
||||
private Size shapeSize = new Size(640, 640);
|
||||
private Letterbox letterbox;
|
||||
|
||||
private List<String> apiFlags = Arrays.asList("CPU_DISABLED");
|
||||
public class YoloV8Predictor extends OnnxYoloV8Predictor {
|
||||
|
||||
public YoloV8Predictor(String modelPath) {
|
||||
this.modelPath = modelPath;
|
||||
init = true;
|
||||
super(modelPath);
|
||||
}
|
||||
|
||||
public YoloV8Predictor(String modelPath, float confThreshold, float nmsThreshold) {
|
||||
this.modelPath = modelPath;
|
||||
this.confThreshold = confThreshold;
|
||||
this.nmsThreshold = nmsThreshold;
|
||||
init = true;
|
||||
super(modelPath, confThreshold, nmsThreshold);
|
||||
}
|
||||
|
||||
|
||||
public void setShapeSize(double width, double height) {
|
||||
this.shapeSize = new Size(width, height);
|
||||
}
|
||||
|
||||
public void setTryNpu(boolean tryNpu) {
|
||||
this.tryNpu = tryNpu;
|
||||
}
|
||||
|
||||
public void setApiFlags(List<String> apiFlags) {
|
||||
this.apiFlags = apiFlags;
|
||||
}
|
||||
|
||||
private OrtSession session;
|
||||
private OrtEnvironment environment;
|
||||
|
||||
private void prepareSession() throws OrtException {
|
||||
if (environment != null) {
|
||||
return;
|
||||
}
|
||||
// 加载ONNX模型
|
||||
environment = OrtEnvironment.getEnvironment();
|
||||
OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions();
|
||||
addNNApiProvider(sessionOptions);
|
||||
|
||||
session = environment.createSession(modelPath, sessionOptions);
|
||||
// 输出基本信息
|
||||
session.getInputInfo().keySet().forEach(x -> {
|
||||
try {
|
||||
System.out.println("input name = " + x);
|
||||
System.out.println(session.getInputInfo().get(x).getInfo().toString());
|
||||
} catch (OrtException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
});
|
||||
// 如果入参labels无效或未定义,使用模型内置labels
|
||||
if (labels == null || labels.size() == 0) {
|
||||
labels = initLabels(session);
|
||||
}
|
||||
initShapeSize(session);
|
||||
}
|
||||
|
||||
private List<String> initLabels(OrtSession session) throws OrtException {
|
||||
String meteStr = session.getMetadata().getCustomMetadata().get("names");
|
||||
if (meteStr == null) {
|
||||
Log.d(TAG, "initLabels: 读取names失败 无法自动修正labels");
|
||||
return Collections.emptyList();
|
||||
}
|
||||
String[] labels = new String[meteStr.split(",").length];
|
||||
|
||||
Matcher matcher = LABEL_PATTERN.matcher(meteStr);
|
||||
|
||||
int h = 0;
|
||||
while (matcher.find()) {
|
||||
labels[h] = matcher.group(1);
|
||||
h++;
|
||||
}
|
||||
return Arrays.asList(labels);
|
||||
}
|
||||
|
||||
private void initShapeSize(OrtSession session) throws OrtException {
|
||||
String meteStr = session.getMetadata().getCustomMetadata().get("imgsz");
|
||||
Log.d(TAG, "initShapeSize: " + meteStr);
|
||||
if (meteStr == null) {
|
||||
Log.d(TAG, "initShapeSize: 读取imgsz失败 无法自动修正输入大小");
|
||||
return;
|
||||
}
|
||||
Matcher matcher = IMG_SIZE_PATTERN.matcher(meteStr);
|
||||
if (matcher.find()) {
|
||||
String shapeSize = matcher.group(1);
|
||||
if (shapeSize == null) {
|
||||
Log.d(TAG, "initShapeSize: 读取imgsz格式异常 无法自动修正输入大小");
|
||||
return;
|
||||
}
|
||||
this.shapeSize = new Size(Double.parseDouble(shapeSize), Double.parseDouble(shapeSize));
|
||||
Log.d(TAG, "set shape size: " + shapeSize);
|
||||
} else {
|
||||
Log.d(TAG, "initShapeSize: 读取imgsz格式异常 无法自动修正输入大小");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
private void addNNApiProvider(OrtSession.SessionOptions sessionOptions) {
|
||||
if (!tryNpu) {
|
||||
return;
|
||||
}
|
||||
try {
|
||||
List<NNAPIFlags> flags = new ArrayList<>();
|
||||
if (apiFlags.contains("USE_FP16")) {
|
||||
flags.add(NNAPIFlags.USE_FP16);
|
||||
}
|
||||
if (apiFlags.contains("USE_NCHW")) {
|
||||
flags.add(NNAPIFlags.USE_NCHW);
|
||||
}
|
||||
if (apiFlags.contains("CPU_ONLY")) {
|
||||
flags.add(NNAPIFlags.CPU_ONLY);
|
||||
}
|
||||
if (apiFlags.contains("CPU_DISABLED")) {
|
||||
flags.add(NNAPIFlags.CPU_DISABLED);
|
||||
}
|
||||
Log.d(TAG, "addNNApiProvider: 当前启用nnapiFlags:" + new Gson().toJson(apiFlags));
|
||||
sessionOptions.addNnapi(EnumSet.copyOf(flags));
|
||||
Log.d(TAG, "prepareSession: 启用nnapi成功");
|
||||
} catch (Exception e) {
|
||||
Log.e(TAG, "prepareSession: 无法启用nnapi");
|
||||
}
|
||||
}
|
||||
|
||||
private HashMap<String, OnnxTensor> preprocessImage(Mat img) throws OrtException {
|
||||
|
||||
// 读取 image
|
||||
Mat image = img.clone();
|
||||
// 将四通道转换为三通道
|
||||
if (image.channels() == 4) {
|
||||
Imgproc.cvtColor(image, image, Imgproc.COLOR_RGBA2RGB);
|
||||
}
|
||||
Log.d(TAG, "preprocessImage: image's channels: " + image.channels());
|
||||
// 更改 image 尺寸
|
||||
letterbox = new Letterbox();
|
||||
letterbox.setNewShape(this.shapeSize);
|
||||
image = letterbox.letterbox(image);
|
||||
|
||||
int rows = letterbox.getHeight();
|
||||
int cols = letterbox.getWidth();
|
||||
int channels = image.channels();
|
||||
// 转换Mat对象的数据类型为CV_64F,即64位浮点型
|
||||
Mat convertedImage = new Mat();
|
||||
image.convertTo(convertedImage, CvType.CV_64F);
|
||||
|
||||
// 获取整个像素数据
|
||||
double[] pixelData = new double[rows * cols * channels];
|
||||
convertedImage.get(0, 0, pixelData);
|
||||
|
||||
float[] pixels = new float[channels * rows * cols];
|
||||
for (int i = 0; i < rows; i++) {
|
||||
for (int j = 0; j < cols; j++) {
|
||||
for (int k = 0; k < channels; k++) {
|
||||
// 这样设置相当于同时做了image.transpose((2, 0, 1))操作
|
||||
// 重新组织内存访问模式,提高缓存效率
|
||||
pixels[k * rows * cols + i * cols + j] = (float) (pixelData[(i * cols + j) * channels + k] / 255.0);
|
||||
}
|
||||
}
|
||||
}
|
||||
image.release();
|
||||
convertedImage.release();
|
||||
// 创建OnnxTensor对象
|
||||
long[] shape = {1L, (long) channels, (long) rows, (long) cols};
|
||||
OnnxTensor tensor = OnnxTensor.createTensor(environment, FloatBuffer.wrap(pixels), shape);
|
||||
HashMap<String, OnnxTensor> stringOnnxTensorHashMap = new HashMap<>();
|
||||
stringOnnxTensorHashMap.put(session.getInputInfo().keySet().iterator().next(), tensor);
|
||||
return stringOnnxTensorHashMap;
|
||||
}
|
||||
|
||||
private List<Detection> postProcessOutput(OrtSession.Result output) throws OrtException {
|
||||
float[][] outputData = ((float[][][]) output.get(0).getValue())[0];
|
||||
|
||||
outputData = transposeMatrix(outputData);
|
||||
Map<Integer, List<float[]>> class2Bbox = new HashMap<>();
|
||||
|
||||
for (float[] bbox : outputData) {
|
||||
int label = argmax(bbox, 4); // 直接在原数组上进行操作
|
||||
float conf = bbox[label + 4];
|
||||
if (conf < confThreshold) {
|
||||
continue;
|
||||
}
|
||||
|
||||
bbox[4] = conf;
|
||||
|
||||
// xywh to (x1, y1, x2, y2)
|
||||
xywh2xyxy(bbox);
|
||||
|
||||
// skip invalid predictions
|
||||
if (bbox[0] >= bbox[2] || bbox[1] >= bbox[3]) {
|
||||
continue;
|
||||
}
|
||||
|
||||
class2Bbox.computeIfAbsent(label, k -> new ArrayList<>()).add(bbox);
|
||||
}
|
||||
|
||||
List<Detection> detections = new ArrayList<>();
|
||||
for (Map.Entry<Integer, List<float[]>> entry : class2Bbox.entrySet()) {
|
||||
int label = entry.getKey();
|
||||
List<float[]> bboxes = entry.getValue();
|
||||
bboxes = nonMaxSuppression(bboxes, nmsThreshold);
|
||||
for (float[] bbox : bboxes) {
|
||||
String labelString = "";
|
||||
if (labels.size() - 1 < label) {
|
||||
labelString = String.valueOf(label);
|
||||
} else {
|
||||
labelString = labels.get(label);
|
||||
}
|
||||
detections.add(new Detection(labelString, entry.getKey(), Arrays.copyOfRange(bbox, 0, 4), bbox[4]));
|
||||
}
|
||||
}
|
||||
return detections;
|
||||
}
|
||||
|
||||
public List<DetectResult> predictYolo(String imagePath) throws OrtException {
|
||||
return predictYolo(Imgcodecs.imread(imagePath));
|
||||
}
|
||||
|
||||
public List<DetectResult> predictYolo(Mat image) throws OrtException {
|
||||
prepareSession();
|
||||
long start_time = System.currentTimeMillis();
|
||||
Map<String, OnnxTensor> inputMap = preprocessImage(image);
|
||||
// 运行推理
|
||||
try (OrtSession.Result output = session.run(inputMap)) {
|
||||
Log.d(TAG, "predictYolo: onnx run cost " + (System.currentTimeMillis() - start_time) + "ms");
|
||||
List<Detection> detections = postProcessOutput(output);
|
||||
Log.d("YoloV8Predictor", String.format("onnx predict cost: %d ms", (System.currentTimeMillis() - start_time)));
|
||||
return detections.stream().map(detection -> new DetectResult(detection, letterbox))
|
||||
.collect(Collectors.toList());
|
||||
} finally {
|
||||
// 释放资源
|
||||
inputMap.values().forEach(OnnxTensor::close);
|
||||
}
|
||||
}
|
||||
|
||||
public static void xywh2xyxy(float[] bbox) {
|
||||
float x = bbox[0];
|
||||
float y = bbox[1];
|
||||
float w = bbox[2];
|
||||
float h = bbox[3];
|
||||
|
||||
bbox[0] = x - w * 0.5f;
|
||||
bbox[1] = y - h * 0.5f;
|
||||
bbox[2] = x + w * 0.5f;
|
||||
bbox[3] = y + h * 0.5f;
|
||||
}
|
||||
|
||||
public static float[][] transposeMatrix(float[][] m) {
|
||||
float[][] temp = new float[m[0].length][m.length];
|
||||
for (int i = 0; i < m.length; i++) {
|
||||
for (int j = 0; j < m[0].length; j++) {
|
||||
temp[j][i] = m[i][j];
|
||||
}
|
||||
}
|
||||
return temp;
|
||||
}
|
||||
|
||||
public static List<float[]> nonMaxSuppression(List<float[]> bboxes, float iouThreshold) {
|
||||
long start = System.currentTimeMillis();
|
||||
List<float[]> bestBboxes = new ArrayList<>();
|
||||
|
||||
bboxes.sort(Comparator.comparing(a -> a[4]));
|
||||
|
||||
while (!bboxes.isEmpty()) {
|
||||
float[] bestBbox = bboxes.remove(bboxes.size() - 1);
|
||||
bestBboxes.add(bestBbox);
|
||||
bboxes.removeIf(bbox -> computeIOU(bbox, bestBbox) >= iouThreshold);
|
||||
}
|
||||
Log.d(TAG, "nonMaxSuppression: cost " + (System.currentTimeMillis() - start) + "ms");
|
||||
return bestBboxes;
|
||||
}
|
||||
|
||||
public static float computeIOU(float[] box1, float[] box2) {
|
||||
|
||||
float area1 = (box1[2] - box1[0]) * (box1[3] - box1[1]);
|
||||
float area2 = (box2[2] - box2[0]) * (box2[3] - box2[1]);
|
||||
|
||||
float left = Math.max(box1[0], box2[0]);
|
||||
float top = Math.max(box1[1], box2[1]);
|
||||
float right = Math.min(box1[2], box2[2]);
|
||||
float bottom = Math.min(box1[3], box2[3]);
|
||||
|
||||
// 计算交集区域的宽度和高度
|
||||
float width = Math.max(right - left, 0);
|
||||
float height = Math.max(bottom - top, 0);
|
||||
|
||||
// 计算交集面积和并集面积
|
||||
float interArea = width * height;
|
||||
float unionArea = area1 + area2 - interArea;
|
||||
|
||||
// 计算交并比
|
||||
return Math.max(interArea / unionArea, 1e-8f);
|
||||
}
|
||||
|
||||
|
||||
//返回最大值的索引
|
||||
// 优化后的 argmax 函数
|
||||
public static int argmax(float[] a, int start) {
|
||||
float re = -Float.MAX_VALUE;
|
||||
int arg = -1;
|
||||
for (int i = start; i < a.length; i++) {
|
||||
if (a[i] >= re) {
|
||||
re = a[i];
|
||||
arg = i - start;
|
||||
}
|
||||
}
|
||||
return arg;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void release() {
|
||||
if (session != null) {
|
||||
try {
|
||||
session.close();
|
||||
session = null;
|
||||
} catch (OrtException e) {
|
||||
Log.e(TAG, "close session failed" + e);
|
||||
}
|
||||
environment.close();
|
||||
environment = null;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,23 +1,15 @@
|
||||
package com.stardust.autojs.runtime.api;
|
||||
|
||||
import android.media.Image;
|
||||
import android.os.Build;
|
||||
import android.util.Log;
|
||||
|
||||
import com.stardust.autojs.core.image.ImageWrapper;
|
||||
import com.stardust.autojs.ncnn.NcnnYoloV8Predictor;
|
||||
import com.stardust.autojs.onnx.YoloV8Predictor;
|
||||
import com.stardust.autojs.onnx.domain.DetectResult;
|
||||
import com.stardust.autojs.runtime.ScriptRuntime;
|
||||
import com.stardust.autojs.yolo.ModelInitParams;
|
||||
import com.stardust.autojs.yolo.YoloInstance;
|
||||
import com.stardust.autojs.yolo.ncnn.NcnnInitParams;
|
||||
import com.stardust.autojs.yolo.ncnn.NcnnYoloInstanceFactory;
|
||||
import com.stardust.autojs.yolo.onnx.OnnxYoloInstanceFactory;
|
||||
|
||||
import org.opencv.core.CvType;
|
||||
import org.opencv.core.Mat;
|
||||
import org.opencv.core.Rect;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
import ai.onnxruntime.OrtException;
|
||||
import androidx.annotation.RequiresApi;
|
||||
|
||||
/**
|
||||
@ -28,96 +20,26 @@ import androidx.annotation.RequiresApi;
|
||||
public class Yolo {
|
||||
private static final String TAG = "Yolo";
|
||||
|
||||
public YoloInstance createNcnn(String paramPath, String binPath, List<String> labels, Integer imageSize, boolean useGpu) {
|
||||
return new YoloInstance() {
|
||||
|
||||
private NcnnYoloV8Predictor ncnnYoloV8 = new NcnnYoloV8Predictor(paramPath, binPath, labels);
|
||||
|
||||
{
|
||||
ncnnYoloV8.setShapeSize(imageSize);
|
||||
ncnnYoloV8.setUseGpu(useGpu);
|
||||
Log.d(TAG, "ncnnYoloV8 instance initializer: " + ncnnYoloV8.init());
|
||||
}
|
||||
|
||||
@Override
|
||||
public YoloPredictor getPredictor() {
|
||||
return ncnnYoloV8;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<DetectResult> predictYolo(Mat image) {
|
||||
return ncnnYoloV8.predictYolo(image);
|
||||
}
|
||||
};
|
||||
}
|
||||
private final NcnnYoloInstanceFactory ncnnFactory = new NcnnYoloInstanceFactory();
|
||||
private final OnnxYoloInstanceFactory onnxFactory = new OnnxYoloInstanceFactory();
|
||||
|
||||
public YoloInstance createOnnx(String modelPath, List<String> labels, Integer imageSize) {
|
||||
return new YoloInstance() {
|
||||
private YoloV8Predictor onnxYoloV8 = new YoloV8Predictor(modelPath);
|
||||
|
||||
{
|
||||
onnxYoloV8.setLabels(labels);
|
||||
onnxYoloV8.setShapeSize(imageSize, imageSize);
|
||||
}
|
||||
|
||||
@Override
|
||||
public YoloPredictor getPredictor() {
|
||||
return onnxYoloV8;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<DetectResult> predictYolo(Mat image) {
|
||||
try {
|
||||
return onnxYoloV8.predictYolo(image);
|
||||
} catch (OrtException e) {
|
||||
return Collections.emptyList();
|
||||
}
|
||||
}
|
||||
};
|
||||
ModelInitParams params = new ModelInitParams();
|
||||
params.setModelPath(modelPath);
|
||||
params.setLabels(labels);
|
||||
params.setImageSize(imageSize);
|
||||
return onnxFactory.createInstance(params);
|
||||
}
|
||||
|
||||
|
||||
public static abstract class YoloInstance {
|
||||
|
||||
public abstract YoloPredictor getPredictor();
|
||||
|
||||
public abstract List<DetectResult> predictYolo(Mat image);
|
||||
|
||||
public void setConfThreshold(float confThreshold) {
|
||||
getPredictor().setConfThreshold(confThreshold);
|
||||
}
|
||||
|
||||
public void setNmsThreshold(float nmsThreshold) {
|
||||
getPredictor().setNmsThreshold(nmsThreshold);
|
||||
}
|
||||
|
||||
public boolean isInit() {
|
||||
return getPredictor().isInit();
|
||||
}
|
||||
|
||||
public void release() {
|
||||
getPredictor().release();
|
||||
}
|
||||
|
||||
public List<DetectResult> captureAndPredict(ScriptRuntime runtime, Rect rect) {
|
||||
Images images = (Images)runtime.getImages();
|
||||
Image image = images.captureScreenRaw();
|
||||
if (image != null) {
|
||||
ImageWrapper imageWrapper = ImageWrapper.ofImageByMat(image, CvType.CV_8UC4);
|
||||
image.close();
|
||||
Mat mat = imageWrapper.getMat();
|
||||
if (rect != null) {
|
||||
// 裁切图像
|
||||
Mat croppedImage = new Mat(mat, rect);
|
||||
mat.release();
|
||||
mat = croppedImage;
|
||||
}
|
||||
List<DetectResult> results = this.predictYolo(mat);
|
||||
mat.release();
|
||||
return results;
|
||||
}
|
||||
return Collections.emptyList();
|
||||
}
|
||||
|
||||
public YoloInstance createNcnn(String paramPath, String binPath, List<String> labels, Integer imageSize, boolean useGpu) {
|
||||
NcnnInitParams params = new NcnnInitParams();
|
||||
params.setParamPath(paramPath);
|
||||
params.setBinPath(binPath);
|
||||
params.setLabels(labels);
|
||||
params.setImageSize(imageSize);
|
||||
params.setUseGpu(useGpu);
|
||||
return ncnnFactory.createInstance(params);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@ -0,0 +1,94 @@
|
||||
package com.stardust.autojs.yolo;
|
||||
|
||||
import android.util.Log;
|
||||
|
||||
import com.stardust.autojs.yolo.onnx.domain.DetectResult;
|
||||
|
||||
import org.opencv.core.Mat;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* BaseYoloInstance类是一个实现了YoloInstance接口的具体类,用于封装YoloPredictor对象,
|
||||
* 并提供YOLO模型推理的核心功能,包括预测、设置阈值、检查初始化状态以及释放资源。
|
||||
*
|
||||
* @author TonyJiangWJ
|
||||
* @since 2025/1/5
|
||||
*/
|
||||
public class BaseYoloInstance extends YoloInstance {
|
||||
private final YoloPredictor predictor;
|
||||
|
||||
/**
|
||||
* 构造函数,初始化BaseYoloInstance实例。
|
||||
*
|
||||
* @param predictor YoloPredictor对象,用于执行YOLO模型的推理操作。
|
||||
*/
|
||||
public BaseYoloInstance(YoloPredictor predictor) {
|
||||
this.predictor = predictor;
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取当前实例的YoloPredictor对象。
|
||||
*
|
||||
* @return 返回封装的YoloPredictor对象。
|
||||
*/
|
||||
@Override
|
||||
public YoloPredictor getPredictor() {
|
||||
return predictor;
|
||||
}
|
||||
|
||||
/**
|
||||
* 对输入的图像进行YOLO模型推理,返回检测结果列表。
|
||||
*
|
||||
* @param image 输入的图像数据,类型为Mat(通常来自OpenCV)。
|
||||
* @return 返回检测结果列表,如果推理失败则返回空列表。
|
||||
*/
|
||||
@Override
|
||||
public List<DetectResult> predictYolo(Mat image) {
|
||||
try {
|
||||
return predictor.predictYolo(image);
|
||||
} catch (Exception e) {
|
||||
Log.e("BaseYoloInstance", "predictYolo: failed", e);
|
||||
return Collections.emptyList();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 设置YOLO模型的置信度阈值。
|
||||
*
|
||||
* @param confThreshold 置信度阈值,范围通常为0到1。
|
||||
*/
|
||||
@Override
|
||||
public void setConfThreshold(float confThreshold) {
|
||||
predictor.setConfThreshold(confThreshold);
|
||||
}
|
||||
|
||||
/**
|
||||
* 设置YOLO模型的非极大值抑制(NMS)阈值。
|
||||
*
|
||||
* @param nmsThreshold NMS阈值,范围通常为0到1。
|
||||
*/
|
||||
@Override
|
||||
public void setNmsThreshold(float nmsThreshold) {
|
||||
predictor.setNmsThreshold(nmsThreshold);
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查YOLO模型是否已经初始化。
|
||||
*
|
||||
* @return 如果模型已初始化,返回true;否则返回false。
|
||||
*/
|
||||
@Override
|
||||
public boolean isInit() {
|
||||
return predictor.isInit();
|
||||
}
|
||||
|
||||
/**
|
||||
* 释放YOLO模型占用的资源。
|
||||
*/
|
||||
@Override
|
||||
public void release() {
|
||||
predictor.release();
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,40 @@
|
||||
package com.stardust.autojs.yolo;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
|
||||
/**
 * Plain holder for the parameters needed to initialise a model.
 *
 * @author TonyJiangWJ
 * @since 2025/1/5
 */
public class ModelInitParams {

    /** Path of the model file. */
    private String modelPath;

    /** Class labels supplied by the caller. */
    private List<String> labels;

    /** Requested input image size. */
    private Integer imageSize;

    /** @return the model file path */
    public String getModelPath() {
        return this.modelPath;
    }

    /** @param modelPath the model file path */
    public void setModelPath(String modelPath) {
        this.modelPath = modelPath;
    }

    /** @return the class labels */
    public List<String> getLabels() {
        return this.labels;
    }

    /** @param labels the class labels */
    public void setLabels(List<String> labels) {
        this.labels = labels;
    }

    /** @return the requested input image size */
    public Integer getImageSize() {
        return this.imageSize;
    }

    /** @param imageSize the requested input image size */
    public void setImageSize(Integer imageSize) {
        this.imageSize = imageSize;
    }
}
|
||||
102
autojs/src/main/java/com/stardust/autojs/yolo/YoloInstance.java
Normal file
102
autojs/src/main/java/com/stardust/autojs/yolo/YoloInstance.java
Normal file
@ -0,0 +1,102 @@
|
||||
package com.stardust.autojs.yolo;
|
||||
|
||||
|
||||
import android.media.Image;
|
||||
|
||||
import com.stardust.autojs.core.image.ImageWrapper;
|
||||
import com.stardust.autojs.runtime.ScriptRuntime;
|
||||
import com.stardust.autojs.runtime.api.Images;
|
||||
import com.stardust.autojs.yolo.onnx.domain.DetectResult;
|
||||
|
||||
import org.opencv.core.CvType;
|
||||
import org.opencv.core.Mat;
|
||||
import org.opencv.core.Rect;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* YoloInstance是一个抽象类,定义了YOLO实例的基本行为和功能。
|
||||
* 该类提供了YOLO模型推理的核心方法,包括预测、设置阈值、检查初始化状态、释放资源以及捕获屏幕并预测的功能。
|
||||
*
|
||||
* @author TonyJiangWJ
|
||||
* @since 2025/1/5
|
||||
*/
|
||||
public abstract class YoloInstance {
|
||||
|
||||
/**
|
||||
* 获取当前实例的YoloPredictor对象。
|
||||
*
|
||||
* @return 返回封装的YoloPredictor对象。
|
||||
*/
|
||||
public abstract YoloPredictor getPredictor();
|
||||
|
||||
/**
|
||||
* 对输入的图像进行YOLO模型推理,返回检测结果列表。
|
||||
*
|
||||
* @param image 输入的图像数据,类型为Mat(通常来自OpenCV)。
|
||||
* @return 返回检测结果列表。
|
||||
*/
|
||||
public abstract List<DetectResult> predictYolo(Mat image);
|
||||
|
||||
/**
|
||||
* 设置YOLO模型的置信度阈值。
|
||||
*
|
||||
* @param confThreshold 置信度阈值,范围通常为0到1。
|
||||
*/
|
||||
public void setConfThreshold(float confThreshold) {
|
||||
getPredictor().setConfThreshold(confThreshold);
|
||||
}
|
||||
|
||||
/**
|
||||
* 设置YOLO模型的非极大值抑制(NMS)阈值。
|
||||
*
|
||||
* @param nmsThreshold NMS阈值,范围通常为0到1。
|
||||
*/
|
||||
public void setNmsThreshold(float nmsThreshold) {
|
||||
getPredictor().setNmsThreshold(nmsThreshold);
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查YOLO模型是否已经初始化。
|
||||
*
|
||||
* @return 如果模型已初始化,返回true;否则返回false。
|
||||
*/
|
||||
public boolean isInit() {
|
||||
return getPredictor().isInit();
|
||||
}
|
||||
|
||||
/**
|
||||
* 释放YOLO模型占用的资源。
|
||||
*/
|
||||
public void release() {
|
||||
getPredictor().release();
|
||||
}
|
||||
|
||||
/**
|
||||
* 捕获屏幕图像并进行YOLO模型推理。
|
||||
*
|
||||
* @param runtime 脚本运行时环境,用于获取图像捕获功能。
|
||||
* @param rect 指定捕获屏幕的区域,如果为null则捕获整个屏幕。
|
||||
* @return 返回检测结果列表,如果捕获或推理失败则返回空列表。
|
||||
*/
|
||||
public List<DetectResult> captureAndPredict(ScriptRuntime runtime, Rect rect) {
|
||||
Images images = (Images) runtime.getImages();
|
||||
Image image = images.captureScreenRaw();
|
||||
if (image != null) {
|
||||
ImageWrapper imageWrapper = ImageWrapper.ofImageByMat(image, CvType.CV_8UC4);
|
||||
image.close();
|
||||
Mat mat = imageWrapper.getMat();
|
||||
if (rect != null) {
|
||||
// 裁切图像
|
||||
Mat croppedImage = new Mat(mat, rect);
|
||||
mat.release();
|
||||
mat = croppedImage;
|
||||
}
|
||||
List<DetectResult> results = this.predictYolo(mat);
|
||||
mat.release();
|
||||
return results;
|
||||
}
|
||||
return Collections.emptyList();
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,19 @@
|
||||
package com.stardust.autojs.yolo;
|
||||
|
||||
/**
|
||||
* yolo实例抽象工厂,用于创建不同类型的yolo实例 目前支持ncnn和onnx的yolov8版本实例
|
||||
*
|
||||
* @param <P> 模型初始化参数
|
||||
* @author TonyJiangWJ
|
||||
* @since 2025/1/5
|
||||
*/
|
||||
public interface YoloInstanceFactory<P extends ModelInitParams> {
|
||||
/**
|
||||
* 创建yolo实例
|
||||
*
|
||||
* @param initParams 初始化参数
|
||||
* @return 返回yolo实例
|
||||
*/
|
||||
YoloInstance createInstance(P initParams);
|
||||
|
||||
}
|
||||
@ -1,8 +1,11 @@
|
||||
package com.stardust.autojs.runtime.api;
|
||||
package com.stardust.autojs.yolo;
|
||||
|
||||
import android.util.Log;
|
||||
|
||||
import com.stardust.autojs.core.opencv.OpenCVHelper;
|
||||
import com.stardust.autojs.yolo.onnx.domain.DetectResult;
|
||||
|
||||
import org.opencv.core.Mat;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@ -10,7 +13,7 @@ import java.util.List;
|
||||
* @author TonyJiangWJ
|
||||
* @since 2024/6/1
|
||||
*/
|
||||
public class YoloPredictor {
|
||||
public abstract class YoloPredictor {
|
||||
|
||||
static {
|
||||
OpenCVHelper.initIfNeeded(null, () -> {
|
||||
@ -54,6 +57,8 @@ public class YoloPredictor {
|
||||
return init;
|
||||
}
|
||||
|
||||
public abstract List<DetectResult> predictYolo(Mat image) throws Exception;
|
||||
|
||||
public void release() {
|
||||
|
||||
}
|
||||
@ -0,0 +1,34 @@
|
||||
package com.stardust.autojs.yolo.ncnn;
|
||||
|
||||
import com.stardust.autojs.yolo.ModelInitParams;
|
||||
|
||||
public class NcnnInitParams extends ModelInitParams {
|
||||
private String paramPath;
|
||||
private String binPath;
|
||||
|
||||
private boolean useGpu;
|
||||
|
||||
public String getParamPath() {
|
||||
return paramPath;
|
||||
}
|
||||
|
||||
public void setParamPath(String paramPath) {
|
||||
this.paramPath = paramPath;
|
||||
}
|
||||
|
||||
public String getBinPath() {
|
||||
return binPath;
|
||||
}
|
||||
|
||||
public void setBinPath(String binPath) {
|
||||
this.binPath = binPath;
|
||||
}
|
||||
|
||||
public boolean isUseGpu() {
|
||||
return useGpu;
|
||||
}
|
||||
|
||||
public void setUseGpu(boolean useGpu) {
|
||||
this.useGpu = useGpu;
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,27 @@
|
||||
package com.stardust.autojs.yolo.ncnn;
|
||||
|
||||
import android.os.Build;
|
||||
import android.util.Log;
|
||||
|
||||
import com.stardust.autojs.yolo.BaseYoloInstance;
|
||||
import com.stardust.autojs.yolo.YoloInstance;
|
||||
import com.stardust.autojs.yolo.YoloInstanceFactory;
|
||||
|
||||
import androidx.annotation.RequiresApi;
|
||||
|
||||
public class NcnnYoloInstanceFactory implements YoloInstanceFactory<NcnnInitParams> {
|
||||
|
||||
@RequiresApi(api = Build.VERSION_CODES.N)
|
||||
@Override
|
||||
public YoloInstance createInstance(NcnnInitParams initParams) {
|
||||
NcnnYoloV8Predictor predictor = new NcnnYoloV8Predictor(initParams.getParamPath(),
|
||||
initParams.getBinPath(),
|
||||
initParams.getLabels());
|
||||
predictor.setShapeSize(initParams.getImageSize());
|
||||
predictor.setUseGpu(initParams.isUseGpu());
|
||||
Log.d("NcnnYoloInstanceFactory", "ncnnYoloV8 instance initializer: " + predictor.init());
|
||||
return new BaseYoloInstance(predictor);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@ -1,16 +1,15 @@
|
||||
package com.stardust.autojs.ncnn;
|
||||
package com.stardust.autojs.yolo.ncnn;
|
||||
|
||||
import android.os.Build;
|
||||
import android.util.Log;
|
||||
|
||||
import com.google.android.gms.common.util.CollectionUtils;
|
||||
import com.stardust.autojs.onnx.domain.DetectResult;
|
||||
import com.stardust.autojs.runtime.api.YoloPredictor;
|
||||
import com.tony.yolov8ncnn.PredictResult;
|
||||
import com.stardust.autojs.yolo.YoloPredictor;
|
||||
import com.stardust.autojs.yolo.onnx.domain.DetectResult;
|
||||
import com.tony.yolov8ncnn.NcnnPredictorNative;
|
||||
import com.tony.yolov8ncnn.PredictResult;
|
||||
|
||||
import org.opencv.core.Mat;
|
||||
import org.opencv.imgproc.Imgproc;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
@ -22,6 +21,8 @@ import java.util.stream.Collectors;
|
||||
import androidx.annotation.RequiresApi;
|
||||
|
||||
/**
|
||||
* Ncnn YoloV8推理器
|
||||
*
|
||||
* @author TonyJiangWJ
|
||||
* @since 2024/6/1
|
||||
*/
|
||||
@ -0,0 +1,34 @@
|
||||
package com.stardust.autojs.yolo.onnx;
|
||||
|
||||
|
||||
import android.os.Build;
|
||||
|
||||
import com.stardust.autojs.yolo.BaseYoloInstance;
|
||||
import com.stardust.autojs.yolo.ModelInitParams;
|
||||
import com.stardust.autojs.yolo.YoloInstance;
|
||||
import com.stardust.autojs.yolo.YoloInstanceFactory;
|
||||
|
||||
import androidx.annotation.RequiresApi;
|
||||
|
||||
/**
|
||||
* OnnxYoloV8实例创建工厂
|
||||
*
|
||||
* @author TonyJiangWJ
|
||||
* @since 2025/1/5
|
||||
*/
|
||||
public class OnnxYoloInstanceFactory implements YoloInstanceFactory<ModelInitParams> {
|
||||
/**
|
||||
* 创建YoloInstance实例
|
||||
*
|
||||
* @param modelInitParams 初始化参数
|
||||
* @return
|
||||
*/
|
||||
@RequiresApi(api = Build.VERSION_CODES.N)
|
||||
@Override
|
||||
public YoloInstance createInstance(ModelInitParams modelInitParams) {
|
||||
OnnxYoloV8Predictor predictor = new OnnxYoloV8Predictor(modelInitParams.getModelPath());
|
||||
predictor.setLabels(modelInitParams.getLabels());
|
||||
predictor.setShapeSize(modelInitParams.getImageSize(), modelInitParams.getImageSize());
|
||||
return new BaseYoloInstance(predictor);
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,375 @@
|
||||
package com.stardust.autojs.yolo.onnx;
|
||||
|
||||
import android.os.Build;
|
||||
import android.util.Log;
|
||||
|
||||
import com.google.gson.Gson;
|
||||
import com.stardust.autojs.yolo.YoloPredictor;
|
||||
import com.stardust.autojs.yolo.onnx.domain.DetectResult;
|
||||
import com.stardust.autojs.yolo.onnx.domain.Detection;
|
||||
import com.stardust.autojs.yolo.onnx.util.Letterbox;
|
||||
|
||||
import org.opencv.core.CvType;
|
||||
import org.opencv.core.Mat;
|
||||
import org.opencv.core.Size;
|
||||
import org.opencv.imgcodecs.Imgcodecs;
|
||||
import org.opencv.imgproc.Imgproc;
|
||||
|
||||
import java.nio.FloatBuffer;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.Comparator;
|
||||
import java.util.EnumSet;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.regex.Matcher;
|
||||
import java.util.regex.Pattern;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import ai.onnxruntime.OnnxTensor;
|
||||
import ai.onnxruntime.OrtEnvironment;
|
||||
import ai.onnxruntime.OrtException;
|
||||
import ai.onnxruntime.OrtSession;
|
||||
import ai.onnxruntime.providers.NNAPIFlags;
|
||||
import androidx.annotation.RequiresApi;
|
||||
|
||||
/**
|
||||
* @author TonyJiangWJ
|
||||
* @since 2023/8/20
|
||||
* transfer from https://gitee.com/agricultureiot/yolo-onnx-java
|
||||
*/
|
||||
@RequiresApi(api = Build.VERSION_CODES.N)
|
||||
public class OnnxYoloV8Predictor extends YoloPredictor {
|
||||
private static final String TAG = "YoloV8Predictor";
|
||||
private static final Pattern IMG_SIZE_PATTERN = Pattern.compile("\\[(\\d+), \\d+]");
|
||||
private static final Pattern LABEL_PATTERN = Pattern.compile("'([^']*)'");
|
||||
|
||||
private final String modelPath;
|
||||
|
||||
private boolean tryNpu;
|
||||
private Size shapeSize = new Size(640, 640);
|
||||
private Letterbox letterbox;
|
||||
|
||||
private List<String> apiFlags = Arrays.asList("CPU_DISABLED");
|
||||
|
||||
/**
 * Creates a predictor for the given model file, leaving the thresholds unchanged.
 * The session itself is created lazily on the first prediction.
 *
 * @param modelPath path of the ONNX model file
 */
public OnnxYoloV8Predictor(String modelPath) {
    // The predictor is flagged as initialised right away; no model is loaded yet.
    this.init = true;
    this.modelPath = modelPath;
}
|
||||
|
||||
/**
 * Creates a predictor for the given model file with explicit thresholds.
 * The session itself is created lazily on the first prediction.
 *
 * @param modelPath     path of the ONNX model file
 * @param confThreshold minimum confidence for a detection to be kept
 * @param nmsThreshold  IoU threshold used by non-maximum suppression
 */
public OnnxYoloV8Predictor(String modelPath, float confThreshold, float nmsThreshold) {
    // The predictor is flagged as initialised right away; no model is loaded yet.
    this.init = true;
    this.modelPath = modelPath;
    this.nmsThreshold = nmsThreshold;
    this.confThreshold = confThreshold;
}
|
||||
|
||||
|
||||
public void setShapeSize(double width, double height) {
|
||||
this.shapeSize = new Size(width, height);
|
||||
}
|
||||
|
||||
public void setTryNpu(boolean tryNpu) {
|
||||
this.tryNpu = tryNpu;
|
||||
}
|
||||
|
||||
public void setApiFlags(List<String> apiFlags) {
|
||||
this.apiFlags = apiFlags;
|
||||
}
|
||||
|
||||
private OrtSession session;
|
||||
private OrtEnvironment environment;
|
||||
|
||||
private void prepareSession() throws OrtException {
|
||||
if (environment != null) {
|
||||
return;
|
||||
}
|
||||
// 加载ONNX模型
|
||||
environment = OrtEnvironment.getEnvironment();
|
||||
OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions();
|
||||
addNNApiProvider(sessionOptions);
|
||||
|
||||
session = environment.createSession(modelPath, sessionOptions);
|
||||
// 输出基本信息
|
||||
session.getInputInfo().keySet().forEach(x -> {
|
||||
try {
|
||||
System.out.println("input name = " + x);
|
||||
System.out.println(session.getInputInfo().get(x).getInfo().toString());
|
||||
} catch (OrtException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
});
|
||||
// 如果入参labels无效或未定义,使用模型内置labels
|
||||
if (labels == null || labels.size() == 0) {
|
||||
labels = initLabels(session);
|
||||
}
|
||||
initShapeSize(session);
|
||||
}
|
||||
|
||||
private List<String> initLabels(OrtSession session) throws OrtException {
|
||||
String meteStr = session.getMetadata().getCustomMetadata().get("names");
|
||||
if (meteStr == null) {
|
||||
Log.d(TAG, "initLabels: 读取names失败 无法自动修正labels");
|
||||
return Collections.emptyList();
|
||||
}
|
||||
String[] labels = new String[meteStr.split(",").length];
|
||||
|
||||
Matcher matcher = LABEL_PATTERN.matcher(meteStr);
|
||||
|
||||
int h = 0;
|
||||
while (matcher.find()) {
|
||||
labels[h] = matcher.group(1);
|
||||
h++;
|
||||
}
|
||||
return Arrays.asList(labels);
|
||||
}
|
||||
|
||||
private void initShapeSize(OrtSession session) throws OrtException {
|
||||
String meteStr = session.getMetadata().getCustomMetadata().get("imgsz");
|
||||
Log.d(TAG, "initShapeSize: " + meteStr);
|
||||
if (meteStr == null) {
|
||||
Log.d(TAG, "initShapeSize: 读取imgsz失败 无法自动修正输入大小");
|
||||
return;
|
||||
}
|
||||
Matcher matcher = IMG_SIZE_PATTERN.matcher(meteStr);
|
||||
if (matcher.find()) {
|
||||
String shapeSize = matcher.group(1);
|
||||
if (shapeSize == null) {
|
||||
Log.d(TAG, "initShapeSize: 读取imgsz格式异常 无法自动修正输入大小");
|
||||
return;
|
||||
}
|
||||
this.shapeSize = new Size(Double.parseDouble(shapeSize), Double.parseDouble(shapeSize));
|
||||
Log.d(TAG, "set shape size: " + shapeSize);
|
||||
} else {
|
||||
Log.d(TAG, "initShapeSize: 读取imgsz格式异常 无法自动修正输入大小");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
private void addNNApiProvider(OrtSession.SessionOptions sessionOptions) {
|
||||
if (!tryNpu) {
|
||||
return;
|
||||
}
|
||||
try {
|
||||
List<NNAPIFlags> flags = new ArrayList<>();
|
||||
if (apiFlags.contains("USE_FP16")) {
|
||||
flags.add(NNAPIFlags.USE_FP16);
|
||||
}
|
||||
if (apiFlags.contains("USE_NCHW")) {
|
||||
flags.add(NNAPIFlags.USE_NCHW);
|
||||
}
|
||||
if (apiFlags.contains("CPU_ONLY")) {
|
||||
flags.add(NNAPIFlags.CPU_ONLY);
|
||||
}
|
||||
if (apiFlags.contains("CPU_DISABLED")) {
|
||||
flags.add(NNAPIFlags.CPU_DISABLED);
|
||||
}
|
||||
Log.d(TAG, "addNNApiProvider: 当前启用nnapiFlags:" + new Gson().toJson(apiFlags));
|
||||
sessionOptions.addNnapi(EnumSet.copyOf(flags));
|
||||
Log.d(TAG, "prepareSession: 启用nnapi成功");
|
||||
} catch (Exception e) {
|
||||
Log.e(TAG, "prepareSession: 无法启用nnapi");
|
||||
}
|
||||
}
|
||||
|
||||
private HashMap<String, OnnxTensor> preprocessImage(Mat img) throws OrtException {
|
||||
|
||||
// 读取 image
|
||||
Mat image = img.clone();
|
||||
// 将四通道转换为三通道
|
||||
if (image.channels() == 4) {
|
||||
Imgproc.cvtColor(image, image, Imgproc.COLOR_RGBA2RGB);
|
||||
}
|
||||
Log.d(TAG, "preprocessImage: image's channels: " + image.channels());
|
||||
// 更改 image 尺寸
|
||||
letterbox = new Letterbox();
|
||||
letterbox.setNewShape(this.shapeSize);
|
||||
image = letterbox.letterbox(image);
|
||||
|
||||
int rows = letterbox.getHeight();
|
||||
int cols = letterbox.getWidth();
|
||||
int channels = image.channels();
|
||||
// 转换Mat对象的数据类型为CV_64F,即64位浮点型
|
||||
Mat convertedImage = new Mat();
|
||||
image.convertTo(convertedImage, CvType.CV_64F);
|
||||
|
||||
// 获取整个像素数据
|
||||
double[] pixelData = new double[rows * cols * channels];
|
||||
convertedImage.get(0, 0, pixelData);
|
||||
|
||||
float[] pixels = new float[channels * rows * cols];
|
||||
for (int i = 0; i < rows; i++) {
|
||||
for (int j = 0; j < cols; j++) {
|
||||
for (int k = 0; k < channels; k++) {
|
||||
// 这样设置相当于同时做了image.transpose((2, 0, 1))操作
|
||||
// 重新组织内存访问模式,提高缓存效率
|
||||
pixels[k * rows * cols + i * cols + j] = (float) (pixelData[(i * cols + j) * channels + k] / 255.0);
|
||||
}
|
||||
}
|
||||
}
|
||||
image.release();
|
||||
convertedImage.release();
|
||||
// 创建OnnxTensor对象
|
||||
long[] shape = {1L, (long) channels, (long) rows, (long) cols};
|
||||
OnnxTensor tensor = OnnxTensor.createTensor(environment, FloatBuffer.wrap(pixels), shape);
|
||||
HashMap<String, OnnxTensor> stringOnnxTensorHashMap = new HashMap<>();
|
||||
stringOnnxTensorHashMap.put(session.getInputInfo().keySet().iterator().next(), tensor);
|
||||
return stringOnnxTensorHashMap;
|
||||
}
|
||||
|
||||
private List<Detection> postProcessOutput(OrtSession.Result output) throws OrtException {
|
||||
float[][] outputData = ((float[][][]) output.get(0).getValue())[0];
|
||||
|
||||
outputData = transposeMatrix(outputData);
|
||||
Map<Integer, List<float[]>> class2Bbox = new HashMap<>();
|
||||
|
||||
for (float[] bbox : outputData) {
|
||||
int label = argmax(bbox, 4); // 直接在原数组上进行操作
|
||||
float conf = bbox[label + 4];
|
||||
if (conf < confThreshold) {
|
||||
continue;
|
||||
}
|
||||
|
||||
bbox[4] = conf;
|
||||
|
||||
// xywh to (x1, y1, x2, y2)
|
||||
xywh2xyxy(bbox);
|
||||
|
||||
// skip invalid predictions
|
||||
if (bbox[0] >= bbox[2] || bbox[1] >= bbox[3]) {
|
||||
continue;
|
||||
}
|
||||
|
||||
class2Bbox.computeIfAbsent(label, k -> new ArrayList<>()).add(bbox);
|
||||
}
|
||||
|
||||
List<Detection> detections = new ArrayList<>();
|
||||
for (Map.Entry<Integer, List<float[]>> entry : class2Bbox.entrySet()) {
|
||||
int label = entry.getKey();
|
||||
List<float[]> bboxes = entry.getValue();
|
||||
bboxes = nonMaxSuppression(bboxes, nmsThreshold);
|
||||
for (float[] bbox : bboxes) {
|
||||
String labelString = "";
|
||||
if (labels.size() - 1 < label) {
|
||||
labelString = String.valueOf(label);
|
||||
} else {
|
||||
labelString = labels.get(label);
|
||||
}
|
||||
detections.add(new Detection(labelString, entry.getKey(), Arrays.copyOfRange(bbox, 0, 4), bbox[4]));
|
||||
}
|
||||
}
|
||||
return detections;
|
||||
}
|
||||
|
||||
public List<DetectResult> predictYolo(String imagePath) throws OrtException {
|
||||
return predictYolo(Imgcodecs.imread(imagePath));
|
||||
}
|
||||
|
||||
public List<DetectResult> predictYolo(Mat image) throws OrtException {
|
||||
prepareSession();
|
||||
long start_time = System.currentTimeMillis();
|
||||
Map<String, OnnxTensor> inputMap = preprocessImage(image);
|
||||
// 运行推理
|
||||
try (OrtSession.Result output = session.run(inputMap)) {
|
||||
Log.d(TAG, "predictYolo: onnx run cost " + (System.currentTimeMillis() - start_time) + "ms");
|
||||
List<Detection> detections = postProcessOutput(output);
|
||||
Log.d("YoloV8Predictor", String.format("onnx predict cost: %d ms", (System.currentTimeMillis() - start_time)));
|
||||
return detections.stream().map(detection -> new DetectResult(detection, letterbox))
|
||||
.collect(Collectors.toList());
|
||||
} finally {
|
||||
// 释放资源
|
||||
inputMap.values().forEach(OnnxTensor::close);
|
||||
}
|
||||
}
|
||||
|
||||
public static void xywh2xyxy(float[] bbox) {
|
||||
float x = bbox[0];
|
||||
float y = bbox[1];
|
||||
float w = bbox[2];
|
||||
float h = bbox[3];
|
||||
|
||||
bbox[0] = x - w * 0.5f;
|
||||
bbox[1] = y - h * 0.5f;
|
||||
bbox[2] = x + w * 0.5f;
|
||||
bbox[3] = y + h * 0.5f;
|
||||
}
|
||||
|
||||
public static float[][] transposeMatrix(float[][] m) {
|
||||
float[][] temp = new float[m[0].length][m.length];
|
||||
for (int i = 0; i < m.length; i++) {
|
||||
for (int j = 0; j < m[0].length; j++) {
|
||||
temp[j][i] = m[i][j];
|
||||
}
|
||||
}
|
||||
return temp;
|
||||
}
|
||||
|
||||
public static List<float[]> nonMaxSuppression(List<float[]> bboxes, float iouThreshold) {
|
||||
long start = System.currentTimeMillis();
|
||||
List<float[]> bestBboxes = new ArrayList<>();
|
||||
|
||||
bboxes.sort(Comparator.comparing(a -> a[4]));
|
||||
|
||||
while (!bboxes.isEmpty()) {
|
||||
float[] bestBbox = bboxes.remove(bboxes.size() - 1);
|
||||
bestBboxes.add(bestBbox);
|
||||
bboxes.removeIf(bbox -> computeIOU(bbox, bestBbox) >= iouThreshold);
|
||||
}
|
||||
Log.d(TAG, "nonMaxSuppression: cost " + (System.currentTimeMillis() - start) + "ms");
|
||||
return bestBboxes;
|
||||
}
|
||||
|
||||
public static float computeIOU(float[] box1, float[] box2) {
|
||||
|
||||
float area1 = (box1[2] - box1[0]) * (box1[3] - box1[1]);
|
||||
float area2 = (box2[2] - box2[0]) * (box2[3] - box2[1]);
|
||||
|
||||
float left = Math.max(box1[0], box2[0]);
|
||||
float top = Math.max(box1[1], box2[1]);
|
||||
float right = Math.min(box1[2], box2[2]);
|
||||
float bottom = Math.min(box1[3], box2[3]);
|
||||
|
||||
// 计算交集区域的宽度和高度
|
||||
float width = Math.max(right - left, 0);
|
||||
float height = Math.max(bottom - top, 0);
|
||||
|
||||
// 计算交集面积和并集面积
|
||||
float interArea = width * height;
|
||||
float unionArea = area1 + area2 - interArea;
|
||||
|
||||
// 计算交并比
|
||||
return Math.max(interArea / unionArea, 1e-8f);
|
||||
}
|
||||
|
||||
|
||||
//返回最大值的索引
|
||||
// 优化后的 argmax 函数
|
||||
public static int argmax(float[] a, int start) {
|
||||
float re = -Float.MAX_VALUE;
|
||||
int arg = -1;
|
||||
for (int i = start; i < a.length; i++) {
|
||||
if (a[i] >= re) {
|
||||
re = a[i];
|
||||
arg = i - start;
|
||||
}
|
||||
}
|
||||
return arg;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void release() {
|
||||
this.init = false;
|
||||
if (session != null) {
|
||||
try {
|
||||
session.close();
|
||||
session = null;
|
||||
} catch (OrtException e) {
|
||||
Log.e(TAG, "close session failed" + e);
|
||||
}
|
||||
environment.close();
|
||||
environment = null;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1,12 +1,13 @@
|
||||
package com.stardust.autojs.onnx.domain;
|
||||
package com.stardust.autojs.yolo.onnx.domain;
|
||||
|
||||
import android.graphics.Rect;
|
||||
import com.stardust.autojs.onnx.util.Letterbox;
|
||||
|
||||
import com.stardust.autojs.yolo.onnx.util.Letterbox;
|
||||
|
||||
/**
|
||||
* @author TonyJiangWJ
|
||||
* @since 2023/8/20
|
||||
* transfer from https://gitee.com/agricultureiot/yolo-onnx-java
|
||||
* transfer from <a href="https://gitee.com/agricultureiot/yolo-onnx-java">yolo-onnx-java</a>
|
||||
*/
|
||||
public class DetectResult {
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
package com.stardust.autojs.onnx.domain;
|
||||
package com.stardust.autojs.yolo.onnx.domain;
|
||||
|
||||
/**
|
||||
* @author TonyJiangWJ
|
||||
@ -15,14 +15,14 @@ public class Detection {
|
||||
public float confidence;
|
||||
|
||||
|
||||
/**
 * Creates a detection from its label, class id, box and confidence.
 *
 * @param label      label text for the detected class
 * @param clsId      class id
 * @param bbox       box coordinates; the array is stored as-is, not copied
 * @param confidence confidence score
 */
public Detection(String label, Integer clsId, float[] bbox, float confidence) {
    this.label = label;
    this.clsId = clsId;
    this.confidence = confidence;
    this.bbox = bbox;
}
|
||||
|
||||
public Detection(){
|
||||
public Detection() {
|
||||
|
||||
}
|
||||
|
||||
@ -52,12 +52,12 @@ public class Detection {
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return " label="+label +
|
||||
" \t clsId="+clsId +
|
||||
" \t x0="+bbox[0] +
|
||||
" \t y0="+bbox[1] +
|
||||
" \t x1="+bbox[2] +
|
||||
" \t y1="+bbox[3] +
|
||||
" \t score="+confidence;
|
||||
return " label=" + label +
|
||||
" \t clsId=" + clsId +
|
||||
" \t x0=" + bbox[0] +
|
||||
" \t y0=" + bbox[1] +
|
||||
" \t x1=" + bbox[2] +
|
||||
" \t y1=" + bbox[3] +
|
||||
" \t score=" + confidence;
|
||||
}
|
||||
}
|
||||
@ -1,4 +1,4 @@
|
||||
package com.stardust.autojs.onnx.util;
|
||||
package com.stardust.autojs.yolo.onnx.util;
|
||||
|
||||
import org.opencv.core.Core;
|
||||
import org.opencv.core.Mat;
|
||||
@ -8,7 +8,7 @@ import org.opencv.imgproc.Imgproc;
|
||||
/**
|
||||
* @author TonyJiangWJ
|
||||
* @since 2023/8/20
|
||||
* transfer from https://gitee.com/agricultureiot/yolo-onnx-java
|
||||
* transfer from <a href="https://gitee.com/agricultureiot/yolo-onnx-java">yolo-onnx-java</a>
|
||||
*/
|
||||
public class Letterbox {
|
||||
|
||||
Loading…
Reference in New Issue
Block a user