淘先锋技术网

首页 1 2 3 4 5 6 7
  • 将相关参数（锚框 anchors、类别数、阈值、模型输入尺寸、模型路径）改成自己的。
  • 注意输出的 shape：ONNX 有时会把多个输出合并成一个，目前原因未知。
import os

import cv2
import numpy as np
import onnxruntime
import torch

# Anchor sizes per output layer, flattened as [w0, h0, w1, h1, ...] in
# model-input pixels; anchors[i] is used to decode the i-th model output.
anchors = [[22, 55, 38, 44, 32, 88, 70, 56], [22, 22, 17, 35, 37, 24, 27, 34]]
# Anchors per layer (assumes every layer has as many anchors as the first).
ans = len(anchors[0]) // 2
cls_num = 1     # number of object classes
conf_th = 0.25  # minimum objectness * best-class-probability score to keep a box
nms_th = 0.45   # IoU at or above which NMS drops the lower-scoring box
# Model input size in pixels.
# TODO: set these to your model's input width/height (416 is only a placeholder).
model_w = 416
model_h = 416

def get_boxes(output, anchors, *, num_classes=None, conf_thresh=None):
	"""Decode one YOLO-style output layer into score-thresholded boxes.

	Args:
		output: raw layer output; reshaped to (num_anchors, num_classes + 5, h, w),
			so a batch size of 1 is assumed. h and w are read from dims 2 and 3.
		anchors: flat per-layer anchor list [w0, h0, w1, h1, ...]; the anchor
			count is len(anchors) // 2.
		num_classes: number of classes; defaults to the module-level ``cls_num``.
		conf_thresh: keep cells whose score is strictly greater than this;
			defaults to the module-level ``conf_th``.

	Returns:
		Tensor of shape (N, 6) with columns
		(cx, cy, w, h, class_index, score), where cx/cy are normalized to [0, 1]
		by the grid size, w/h are exp(pred) * anchor (same units as the anchors),
		class_index is stored as a float, and score = sigmoid(obj) * max sigmoid(cls).
	"""
	if num_classes is None:
		num_classes = cls_num
	if conf_thresh is None:
		conf_thresh = conf_th
	num_anchors = len(anchors) // 2
	h, w = output.size(2), output.size(3)
	# (A, C+5, H, W) -> (A, H, W, C+5) so per-cell attributes are on the last axis
	pred = output.reshape(num_anchors, num_classes + 5, h, w).permute(0, 2, 3, 1)

	obj = torch.sigmoid(pred[..., 4])
	cls_prob, cls_idx = torch.max(torch.sigmoid(pred[..., 5:]), -1)
	score = obj * cls_prob
	mask = score > conf_thresh

	dtype, device = pred.dtype, pred.device
	# Cell offsets and anchor sizes broadcast against (A, H, W). Built with
	# arange/view rather than meshgrid to avoid its version-dependent `indexing`.
	grid_x = torch.arange(w, dtype=dtype, device=device).view(1, 1, w)
	grid_y = torch.arange(h, dtype=dtype, device=device).view(1, h, 1)
	anchor_w = torch.tensor(anchors[0::2], dtype=dtype, device=device).view(num_anchors, 1, 1)
	anchor_h = torch.tensor(anchors[1::2], dtype=dtype, device=device).view(num_anchors, 1, 1)

	cx = (torch.sigmoid(pred[..., 0]) + grid_x) / w
	cy = (torch.sigmoid(pred[..., 1]) + grid_y) / h
	bw = torch.exp(pred[..., 2]) * anchor_w
	bh = torch.exp(pred[..., 3]) * anchor_h

	# cls_idx is int64; cast it so the stacked result has a single float dtype
	# (older torch versions refuse to concatenate mixed dtypes).
	return torch.stack(
		[cx[mask], cy[mask], bw[mask], bh[mask], cls_idx[mask].to(dtype), score[mask]],
		-1,
	)
	
def iou(a, b):
	"""Pairwise intersection-over-union of two sets of center-format boxes.

	Args:
		a: (A, >=4) tensor; columns 0..3 are (cx, cy, w, h).
		b: (B, >=4) tensor in the same format.

	Returns:
		(A, B) tensor where entry [i, j] is IoU(a[i], b[j]). Computed on the
		inputs' device (no hard-coded CUDA).
	"""
	# (cx, cy, w, h) -> top-left / bottom-right corners
	a_lt = a[:, 0:2] - a[:, 2:4] / 2.0
	a_rb = a_lt + a[:, 2:4]
	b_lt = b[:, 0:2] - b[:, 2:4] / 2.0
	b_rb = b_lt + b[:, 2:4]
	# Broadcast (A, 1, 2) against (1, B, 2) to get every pair's overlap rectangle
	lt = torch.max(a_lt[:, None, :], b_lt[None, :, :])
	rb = torch.min(a_rb[:, None, :], b_rb[None, :, :])
	wh = torch.clamp(rb - lt, min=0)
	inter = wh[..., 0] * wh[..., 1]
	area_a = (a[:, 2] * a[:, 3])[:, None]
	area_b = (b[:, 2] * b[:, 3])[None, :]
	return inter / (area_a + area_b - inter)

def nms(box, *, iou_thresh=None):
	"""Greedy non-maximum suppression over a set of boxes.

	Args:
		box: (N, K) tensor whose columns 0..3 are (cx, cy, w, h) and whose last
			column is the score.
		iou_thresh: a box is dropped when its IoU with an already-kept box is
			>= this value; defaults to the module-level ``nms_th``.

	Returns:
		(M, K) tensor of kept rows, highest score first. An empty input is
		returned unchanged.
	"""
	if iou_thresh is None:
		iou_thresh = nms_th
	if len(box) == 0:
		return box
	# Highest score first: greedy NMS must keep the best box and suppress its
	# lower-scoring neighbours (ascending order would keep the worst one).
	box = box[torch.argsort(box[:, -1], descending=True)]
	kept = []
	while len(box) > 0:
		kept.append(box[0])
		if len(box) == 1:
			break
		overlap = iou(box[0:1, 0:4], box[1:, 0:4]).squeeze(0)
		box = box[1:][overlap < iou_thresh]
	return torch.stack(kept)

def deal(boxes):
	"""Run NMS independently per class and merge the surviving boxes.

	Args:
		boxes: (N, K) tensor whose column -2 holds the class id and whose last
			column holds the score used by ``nms``.

	Returns:
		Tensor of kept rows, grouped by class id in ascending order. An empty
		input is returned unchanged.
	"""
	if len(boxes) == 0:
		return boxes
	# unique() returns sorted ids, so the output is grouped by class.
	return torch.cat([nms(boxes[boxes[:, -2] == label]) for label in boxes[:, -2].unique()])

# TODO: replace '.onnx' with the path to your exported model.
# onnxruntime >= 1.9 builds with several execution providers (e.g. the GPU
# package) require `providers` to be passed explicitly.
session = onnxruntime.InferenceSession('.onnx', None, providers=onnxruntime.get_available_providers())
input_name = session.get_inputs()[0].name
for node in session.get_outputs():
	print(node.name)
# Post-process on the GPU when one is available instead of requiring CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# cv2.imwrite silently fails (returns False) when the directory is missing.
os.makedirs("result", exist_ok=True)
with open("jpg.txt") as image_list:
	for line in image_list:
		line = line.strip()
		if not line:
			continue
		print(line)
		raw = cv2.imread(line)
		if raw is None:
			# cv2.imread returns None for missing or unreadable files.
			print("failed to read image, skipping:", line)
			continue
		ih, iw, _ = raw.shape
		im = cv2.resize(raw, (model_w, model_h))
		im = im[..., ::-1]  # OpenCV loads BGR; reverse channels to RGB
		im = im.astype('float32') / 255.0
		# HWC -> 1xCxHxW; make the transposed view contiguous for onnxruntime.
		image = np.ascontiguousarray(np.expand_dims(np.transpose(im, (2, 0, 1)), 0))
		outputs = session.run(None, {input_name: image})
		print(type(outputs))
		thld_boxes = []
		# NOTE: assumes one output per anchor layer, in the same order as
		# `anchors`; some exports merge outputs, so check the printed shapes.
		for i, output in enumerate(outputs):
			output = torch.from_numpy(output).to(device)
			print(output.shape)
			boxes = get_boxes(output, anchors[i])
			if len(boxes) == 0:
				continue
			# cx/cy are normalized -> scale by the original image size;
			# w/h are in model-input pixels -> rescale to the original image.
			boxes[:, 0] = boxes[:, 0] * iw
			boxes[:, 1] = boxes[:, 1] * ih
			boxes[:, 2] = boxes[:, 2] / model_w * iw
			boxes[:, 3] = boxes[:, 3] / model_h * ih
			thld_boxes.append(boxes)
		if thld_boxes:
			boxes = deal(torch.cat(thld_boxes))
			print(len(boxes))
			for cx, cy, bw, bh in boxes[:, 0:4].cpu().tolist():
				cv2.rectangle(
					raw,
					(int(cx - bw / 2), int(cy - bh / 2)),
					(int(cx + bw / 2), int(cy + bh / 2)),
					(0, 0, 255),
				)
		cv2.imwrite(os.path.join("result", os.path.basename(line)), raw)