utils.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436
  1. # coding=utf-8
  2. import math
  3. from collections import Counter
  4. import numpy as np
  5. import cv2
  6. from imutils import auto_canny, contours
  7. from e import PolyNodeCountError
  8. from score import score
  9. from settings import CHOICES, SHEET_AREA_MIN_RATIO, PROCESS_BRIGHT_COLS, PROCESS_BRIGHT_ROWS, BRIGHT_VALUE, \
  10. CHOICE_COL_COUNT, CHOICES_PER_QUE, WHITE_RATIO_PER_CHOICE, MAYBE_MULTI_CHOICE_THRESHOLD, CHOICE_CNT_COUNT, test_ans, \
  11. ORIENT_CODE
  def get_corner_node_list(poly_node_list):
      """
      Return the four corner points of a quadrilateral contour.

      Points are classified by quadrant around the centroid of the polygon
      (the coordinate sum is divided by 4, so four points are expected).
      Quadrants are decided on (row, column) = (node[0, 1], node[0, 0]):

      * top_left     - smaller row and smaller column than the centre
      * bottom_left  - smaller row, larger column
      * top_right    - larger row, smaller column
      * bottom_right - larger row and larger column

      Assuming OpenCV's [x, y] point layout, ``bottom_left`` is therefore the
      upper-right point of the image and ``top_right`` the lower-left one;
      get_roi_img and detect_cnt_again consume the labels in this sense.

      A point lying exactly on a centre line is ignored, a later point in the
      same quadrant replaces an earlier one, and an empty quadrant yields None.

      :type poly_node_list: ndarray
      :return: tuple (top_left, bottom_left, top_right, bottom_right)
      """
      mean_point = (np.sum(poly_node_list, axis=0) / 4)[0]
      col_center, row_center = mean_point[0], mean_point[1]
      # (row is past the centre, column is past the centre) -> latest node
      quadrants = {}
      for node in poly_node_list:
          col, row = node[0, 0], node[0, 1]
          if row == row_center or col == col_center:
              continue
          quadrants[(bool(row > row_center), bool(col > col_center))] = node
      return (quadrants.get((False, False)),
              quadrants.get((False, True)),
              quadrants.get((True, False)),
              quadrants.get((True, True)))
  def detect_cnt_again(poly, base_img):
      """
      Check again whether the cropped region still contains the answer-sheet
      area, and return a perspective-corrected image.

      The region spanned by ``poly`` is cropped and pre-processed, and its
      largest contour is searched for. If that contour covers more than
      SHEET_AREA_MIN_RATIO of the crop, it is approximated by a polygon and
      its corners replace the original ones. The four corners (in
      get_corner_node_list order) are then mapped onto (0, 0), (w, 0),
      (0, h), (w, h) of an output the size of ``base_img``.

      :param poly: ndarray, polygon as returned by cv2.approxPolyDP
      :param base_img: ndarray, the full image
      :return: ndarray, the warped crop when re-detection succeeded,
          otherwise the warped ``base_img``
      :raises PolyNodeCountError: if the final polygon does not have exactly
          four points or a corner quadrant is empty
      """
      # Whether the polygon region still contains the answer-sheet region.
      flag = False
      # Compute the polygon's four corners, crop, then pre-process the crop.
      top_left, bottom_left, top_right, bottom_right = get_corner_node_list(poly)
      if any(node is None for node in (top_left, bottom_left, top_right, bottom_right)):
          raise PolyNodeCountError
      roi_img = get_roi_img(base_img, bottom_left, bottom_right, top_left, top_right)
      img = get_init_process_img(roi_img)
      # Largest-area contour inside the crop.
      cnt = get_max_area_cnt(img)
      # If that contour is large enough, recompute the polygon's corners from it.
      if cv2.contourArea(cnt) > roi_img.shape[0] * roi_img.shape[1] * SHEET_AREA_MIN_RATIO:
          flag = True
          # arcLength takes the contour itself; the old `(cnt,)` tuple is not
          # a valid curve for cv2 and made this branch fail.
          poly = cv2.approxPolyDP(cnt, cv2.arcLength(cnt, True) * 0.1, True)
          top_left, bottom_left, top_right, bottom_right = get_corner_node_list(poly)
      corners = (top_left, bottom_left, top_right, bottom_right)
      if poly.shape[0] != 4 or any(node is None for node in corners):
          raise PolyNodeCountError
      # Polygon corners and image corners, used for deskewing.
      height, width = base_img.shape[0], base_img.shape[1]
      base_poly_nodes = np.float32([node[0] for node in corners])
      base_nodes = np.float32([[0, 0],
                               [width, 0],
                               [0, height],
                               [width, height]])
      transmtx = cv2.getPerspectiveTransform(base_poly_nodes, base_nodes)
      source = roi_img if flag else base_img
      return cv2.warpPerspective(source, transmtx, (width, height))
  def get_init_process_img(roi_img):
      """
      Pre-process an image for contour detection.

      Pipeline: sum of the y- and x-direction Sobel responses converted to
      8-bit absolute values, 3x3 Gaussian blur, binary threshold at 120,
      two erode/dilate rounds, and finally imutils.auto_canny.

      :param roi_img: ndarray, the (cropped) image to process
      :return: ndarray, the edge map produced by auto_canny
      """
      # NOTE(review): in the cv2 binding Sobel(src, ddepth, dx, dy[, dst[, ksize...]])
      # the fifth positional argument is `dst`, so this -1 is probably not acting
      # as ksize=-1 (Scharr). Confirm the intent before changing it.
      grad_y = cv2.Sobel(roi_img, cv2.CV_32F, 0, 1, -1)
      grad_x = cv2.Sobel(roi_img, cv2.CV_32F, 1, 0, -1)
      edges = cv2.convertScaleAbs(cv2.add(grad_y, grad_x))
      edges = cv2.GaussianBlur(edges, (3, 3), 0)
      _, edges = cv2.threshold(edges, 120, 255, cv2.THRESH_BINARY)
      # NOTE(review): a 1x1 all-ones kernel makes erode/dilate leave the image
      # unchanged; a larger kernel was probably intended.
      kernel = np.ones((1, 1), np.uint8)
      for _ in range(2):
          edges = cv2.erode(edges, kernel, iterations=1)
          edges = cv2.dilate(edges, kernel, iterations=2)
      return auto_canny(edges)
  def get_roi_img(base_img, bottom_left, bottom_right, top_left, top_right, margin=10):
      """
      Crop the axis-aligned region spanned by four corner points, shrunk by
      ``margin`` pixels on every side.

      Corner labels follow get_corner_node_list. Each corner is indexed as
      node[0, 0] (column) and node[0, 1] (row). The row range runs from the
      smaller row of top_left/bottom_left to the larger row of
      top_right/bottom_right. The column range runs from the smaller column
      of top_left/top_right to the larger column of bottom_left/bottom_right.

      :param base_img: ndarray
      :param bottom_left: ndarray
      :param bottom_right: ndarray
      :param top_left: ndarray
      :param top_right: ndarray
      :param margin: pixels trimmed from each side (default 10, the value
          previously hard-coded)
      :return: ndarray, a view into ``base_img`` (may be empty)
      """
      min_v = min(top_left[0, 1], bottom_left[0, 1])
      max_v = max(top_right[0, 1], bottom_right[0, 1])
      min_h = min(top_left[0, 0], top_right[0, 0])
      max_h = max(bottom_left[0, 0], bottom_right[0, 0])
      # Clamp at 0: a negative bound would otherwise count from the far edge
      # of the image and return an unrelated strip.
      row_start = max(min_v + margin, 0)
      row_stop = max(max_v - margin, 0)
      col_start = max(min_h + margin, 0)
      col_stop = max(max_h - margin, 0)
      return base_img[row_start:row_stop, col_start:col_stop]
  def get_max_area_cnt(img):
      """
      Return the external contour of ``img`` that encloses the largest area.

      :param img: ndarray, the binary/edge image to search (e.g. the output
          of get_init_process_img)
      :return: ndarray, the contour with the largest cv2.contourArea
      :raises ValueError: if no contour is found
      """
      # findContours returns (contours, hierarchy) in OpenCV 2.x/4.x but
      # (image, contours, hierarchy) in 3.x; contours are second-to-last in all.
      cnts = cv2.findContours(img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
      if len(cnts) == 0:
          raise ValueError('no contour found in image')
      return max(cnts, key=cv2.contourArea)
  def get_ans(ans_img, rows, fill_threshold=0.8):
      """
      Read the marked choices of every question and compare them with test_ans.

      Each row of boxes holds get_items_per_row() questions side by side.
      Every question spans get_item_interval() boxes: its question-number box
      followed by its choice boxes. A choice counts as marked when the share
      of non-zero ("white") pixels inside its box is below ``fill_threshold``.

      :param ans_img: ndarray, image whose non-zero pixels count as white
          (presumably the binarised answer area - confirm with caller)
      :param rows: rows of bounding boxes (x, y, w, h), each row sorted by x
      :param fill_threshold: white ratio below which a box counts as filled
          (default 0.8, the value previously hard-coded)
      :return: tuple (rows, test_is_eq(answers, test_ans))
      """
      interval = get_item_interval()
      items_per_row = get_items_per_row()
      ans = []
      for i, row in enumerate(rows):
          for k in range(items_per_row):
              print('=======================================')
              # Skip the leading question-number box of this item.
              item_boxes = row[1 + k * interval:interval + k * interval]
              marked = ''
              for j, box in enumerate(item_boxes):
                  try:
                      cell = ans_img[box[1]:(box[1] + box[3]), box[0]:(box[0] + box[2])]
                      # Share of white pixels; an empty cell counts as blank
                      # instead of dividing by zero.
                      percent = np.count_nonzero(cell) * 1.0 / cell.size if cell.size else 1
                  except IndexError:
                      percent = 1
                  if percent < fill_threshold:
                      marked += CHOICES[j]
              answer = ''.join(sorted(marked))
              ans.append(answer)
              print(u'第{0}排第{1}列的作答:{2}'.format(i + 1, k + 1, answer))
      print('=====总分========')
      return rows, test_is_eq(ans, test_ans)
  def test_is_eq(ans, test_ans, per_row=4):
      """
      Compare recognised answers with the expected ones.

      Each mismatch is printed as "<row> <column> <answer>". The row is
      1-based and the column 0-based, with ``per_row`` answers per row.
      Answers beyond the end of ``test_ans`` and expected answers missing
      from ``ans`` both count as mismatches.

      :param ans: list of recognised answer strings
      :param test_ans: list of expected answer strings
      :param per_row: answers per row, used only for the printed position
          (default 4, the value previously hard-coded)
      :return: tuple (all_equal, mismatch_count)
      """
      count = 0
      for i, a in enumerate(ans):
          if i >= len(test_ans) or a != test_ans[i]:
              print('%s %s %s' % (i // per_row + 1, i % per_row, a))
              count += 1
      # Expected answers that have no recognised counterpart are mismatches too.
      count += max(len(test_ans) - len(ans), 0)
      return count == 0, count


  # The name starts with "test_", so keep pytest from collecting this helper.
  test_is_eq.__test__ = False
  def get_items_per_row():
      """
      Return how many questions sit side by side in one row of boxes.

      Each question occupies get_item_interval() boxes (its number plus its
      choices) out of CHOICE_COL_COUNT boxes per row. Floor division keeps
      the result an int on Python 3 too, where ``/`` would yield a float and
      break ``range(get_items_per_row())``.

      :return: int
      """
      return CHOICE_COL_COUNT // get_item_interval()
  def get_item_interval():
      """
      Return the number of boxes one question occupies in a row: its
      CHOICES_PER_QUE choice boxes plus one question-number box.

      :return: int
      """
      number_box = 1
      return CHOICES_PER_QUE + number_box
  def delete_rect(cents_pos, que_cnts):
      """
      Drop contours whose bounding box is not roughly square or not well filled.

      A contour is removed when its box's width/height ratio lies outside
      [0.5, 2], or when its contour area covers less than half of the box.

      :param cents_pos: bounding boxes (x, y, w, h), aligned index-for-index
          with ``que_cnts``
      :param que_cnts: list of contours; it is filtered in place
      :return: the same list object ``que_cnts``
      """
      kept = []
      for box, cnt in zip(cents_pos, que_cnts):
          fill_ratio = cv2.contourArea(cnt) / (box[2] * box[3])
          aspect = 1.0 * box[2] / box[3]
          if 0.5 <= aspect <= 2 and fill_ratio >= 0.5:
              kept.append(cnt)
      # Contours without a matching box are not examined; keep them as they were.
      kept.extend(que_cnts[len(cents_pos):])
      # Rebuild in place (one pass instead of repeated pops): callers may hold
      # a reference to this list.
      que_cnts[:] = kept
      return que_cnts
  def get_left_right(cnts, max_span=10):
      """
      Keep only the contours lying horizontally within the answer grid.

      Contours are sorted top-to-bottom and filtered by delete_rect.
      Among the runs of CHOICE_COL_COUNT consecutive boxes, the one with the
      smallest vertical spread (last y - first y) is taken as a complete row;
      the last such run wins on ties. Contours whose box x lies outside that
      row's extent (half a box width of slack on each end) are dropped.

      :param cnts: contours to filter
      :param max_span: largest vertical spread (pixels) accepted for the
          reference row (default 10, the value previously hard-coded)
      :return: list of remaining contours, sorted top-to-bottom
      :raises ValueError: if there are too few boxes or no sufficiently
          aligned row exists
      """
      sorted_cnts, boxes = contours.sort_contours(cnts, method="top-to-bottom")
      kept = delete_rect(boxes, list(sorted_cnts))
      que_cnts, cents_pos = contours.sort_contours(kept, method="top-to-bottom")
      window = CHOICE_COL_COUNT
      best_span = best_window = None
      for start in range(len(cents_pos) - window + 1):
          # The sum of consecutive y gaps telescopes to last.y - first.y.
          span = cents_pos[start + window - 1][1] - cents_pos[start][1]
          if best_span is None or span <= best_span:
              best_span, best_window = span, cents_pos[start:start + window]
      if best_window is None:
          raise ValueError('fewer than %d boxes to form a row' % window)
      if best_span >= max_span:
          raise ValueError('no row of %d horizontally aligned boxes found (spread %s)'
                           % (window, best_span))
      by_x = sorted(best_window, key=lambda box: box[0])
      left = by_x[0][0] - by_x[0][2] * 0.5
      right = by_x[-1][0] + by_x[-1][2] * 0.5
      return [cnt for cnt, box in zip(que_cnts, cents_pos) if left <= box[0] <= right]
  def get_top_bottom(cnts, max_span=10):
      """
      Keep only the contours lying vertically within the answer grid.

      Contours are sorted left-to-right. Among the runs of
      get_choice_row_count() consecutive boxes, the one with the smallest
      horizontal spread (last x - first x) is taken as a complete column;
      the last such run wins on ties. Contours whose box y lies outside that
      column's extent (half a box height of slack on each end) are dropped.

      :param cnts: contours to filter
      :param max_span: largest horizontal spread (pixels) accepted for the
          reference column (default 10, the value previously hard-coded)
      :return: list of remaining contours, sorted left-to-right
      :raises ValueError: if there are too few boxes or no sufficiently
          aligned column exists
      """
      que_cnts, cents_pos = contours.sort_contours(cnts, method="left-to-right")
      window = get_choice_row_count()
      best_span = best_window = None
      for start in range(len(cents_pos) - window + 1):
          # The sum of consecutive x gaps telescopes to last.x - first.x.
          span = cents_pos[start + window - 1][0] - cents_pos[start][0]
          if best_span is None or span <= best_span:
              best_span, best_window = span, cents_pos[start:start + window]
      if best_window is None:
          raise ValueError('fewer than %d boxes to form a column' % window)
      if best_span >= max_span:
          raise ValueError('no column of %d vertically aligned boxes found (spread %s)'
                           % (window, best_span))
      by_y = sorted(best_window, key=lambda box: box[1])
      top = by_y[0][1] - by_y[0][3] * 0.5
      bottom = by_y[-1][1] + by_y[-1][3] * 0.5
      return [cnt for cnt, box in zip(que_cnts, cents_pos) if top <= box[1] <= bottom]
  def get_choice_row_count():
      """
      Return the number of box rows needed for CHOICE_CNT_COUNT boxes laid
      out CHOICE_COL_COUNT per row (rounded up).

      :return: int
      """
      exact_rows = CHOICE_CNT_COUNT * 1.0 / CHOICE_COL_COUNT
      return int(math.ceil(exact_rows))
  def sort_by_row(cnts_pos):
      """
      Group bounding boxes (already sorted top-to-bottom) into rows.

      Up to get_choice_row_count() rows are built. Each row takes at most
      CHOICE_COL_COUNT consecutive boxes and stops early at the first
      vertical gap, measured from the immediately preceding box, that is not
      below get_min_row_interval(). Every row is then sorted by x.

      :param cnts_pos: sequence of boxes (x, y, w, h) sorted by y
      :return: list of rows (lists of boxes)
      :raises ValueError: via ck_full_rows_size if no row is complete
      """
      choice_row_count = get_choice_row_count()
      threshold = get_min_row_interval(cnts_pos)
      rows = []
      start = 0
      for _ in range(choice_row_count):
          chunk = cnts_pos[start:start + CHOICE_COL_COUNT]
          if len(chunk) == 0:
              # Input exhausted; let ck_full_rows_size judge what was built.
              break
          row = [chunk[0]]
          # Compare each box with its direct predecessor (the old code compared
          # with the box two positions back, and the last one for the 2nd box).
          for prev, box in zip(chunk, chunk[1:]):
              if box[1] - prev[1] < threshold:
                  row.append(box)
              else:
                  break
          start += len(row)
          row.sort(key=lambda box: box[0])
          rows.append(row)
      ck_full_rows_size(rows)
      return rows
  def sort_by_col(cnts_pos):
      """
      Group bounding boxes into columns.

      Boxes are sorted by x into a new list; the caller's sequence is no
      longer sorted in place. Up to CHOICE_COL_COUNT columns are then built.
      Each column takes at most get_choice_row_count() consecutive boxes and
      stops early at the first horizontal gap, measured from the immediately
      preceding box, that is not below get_min_col_interval(). Every column
      is then sorted by y.

      :param cnts_pos: sequence of boxes (x, y, w, h)
      :return: list of columns (lists of boxes)
      :raises ValueError: via ck_full_cols_size if no column is complete
      """
      cnts_pos = sorted(cnts_pos, key=lambda box: box[0])
      choice_row_count = get_choice_row_count()
      threshold = get_min_col_interval(cnts_pos)
      cols = []
      start = 0
      for _ in range(CHOICE_COL_COUNT):
          chunk = cnts_pos[start:start + choice_row_count]
          if not chunk:
              # Input exhausted; let ck_full_cols_size judge what was built.
              break
          col = [chunk[0]]
          # Compare each box with its direct predecessor (the old code compared
          # with the box two positions back, and the last one for the 2nd box).
          for prev, box in zip(chunk, chunk[1:]):
              if box[0] - prev[0] < threshold:
                  col.append(box)
              else:
                  break
          start += len(col)
          col.sort(key=lambda box: box[1])
          cols.append(col)
      ck_full_cols_size(cols)
      return cols
  def insert_null_2_rows(cols, rows):
      """
      Insert placeholder boxes into ``rows`` (in place) using ``cols``.

      Boxes are (x, y, w, h) tuples, as indexed in get_ans. For every row and
      every column index j:

      * if row[j] equals the head of column j, that box is popped from the
        column and remembered in temp[j];
      * otherwise a synthetic box is inserted at row position j. It takes x
        and w from col[1] and y and h from row[j].

      When one of those lookups raises IndexError, a box is inserted whose
      y and h come from row[j - 1]. Its x and w come from col[1] or, if that
      also fails, from temp[j], the last box popped from column j.

      Mutates both ``rows`` (inserts) and the column lists (pops); returns None.
      NOTE(review): temp[j] raises an uncaught KeyError if nothing has been
      popped from column j yet, and a second IndexError there is not caught.
      """
      # temp[j]: most recent box consumed from column j
      temp = {}
      for i, row in enumerate(rows):
          for j, col in enumerate(cols):
              try:
                  if row[j] != col[0]:
                      row.insert(j, (col[1][0], row[j][1], col[1][2], row[j][3]))
                  else:
                      temp[j] = col.pop(0)
              except IndexError:
                  # row[j] / col[0] / col[1] missing: fall back to the previous box's y/h
                  try:
                      row.insert(j, (col[1][0], row[j - 1][1], col[1][2], row[j - 1][3]))
                  except IndexError:
                      # column exhausted: reuse x/w of the last box popped from it
                      row.insert(j, (temp[j][0], row[j - 1][1], temp[j][2], row[j - 1][3]))
  def get_min_row_interval(cnts_pos):
      """
      Return the smallest of the (get_choice_row_count() - 1) largest
      vertical gaps between consecutive boxes.

      sort_by_row uses this value as its row-split threshold: a gap below it
      keeps a box in the current row.

      :param cnts_pos: sequence of boxes (x, y, w, h) sorted by y
      :return: the threshold gap
      """
      breaks_needed = get_choice_row_count() - 1
      gaps = [cur[1] - prev[1] for prev, cur in zip(cnts_pos, cnts_pos[1:])]
      largest_first = sorted(gaps, reverse=True)
      return min(largest_first[:breaks_needed])
  def get_min_col_interval(cnts_pos):
      """
      Return the smallest of the (CHOICE_COL_COUNT - 1) largest horizontal
      gaps between consecutive boxes.

      sort_by_col uses this value as its column-split threshold: a gap below
      it keeps a box in the current column.

      :param cnts_pos: sequence of boxes (x, y, w, h) sorted by x
      :return: the threshold gap
      """
      gaps = [cur[0] - prev[0] for prev, cur in zip(cnts_pos, cnts_pos[1:])]
      largest_first = sorted(gaps, reverse=True)
      return min(largest_first[:CHOICE_COL_COUNT - 1])
  def get_min_interval(cnts_pos, orient):
      """
      Unfinished stub, apparently meant to generalise get_min_row_interval /
      get_min_col_interval by orientation.

      It currently only looks up ``ORIENT_CODE[orient]`` (an unknown
      ``orient`` therefore raises KeyError) and returns None. ``cnts_pos`` is
      not used yet.
      """
      # TODO: compute the gaps along coordinate ORIENT_CODE[orient].
      _ = ORIENT_CODE[orient]
      return None
  def ck_full_rows_size(rows, expected=None):
      """
      Ensure that at least one row holds a complete set of boxes.

      :param rows: list of rows (each a list of boxes)
      :param expected: number of boxes in a complete row; defaults to
          CHOICE_COL_COUNT
      :raises ValueError: if no row has exactly ``expected`` boxes (the
          former bare ``raise`` had no active exception and carried no message)
      """
      if expected is None:
          expected = CHOICE_COL_COUNT
      if not any(len(row) == expected for row in rows):
          raise ValueError('no row contains %s boxes' % expected)
  def ck_full_cols_size(rows, expected=None):
      """
      Ensure that at least one column holds a complete set of boxes.

      Despite the parameter name, sort_by_col passes its list of columns here.

      :param rows: list of columns (each a list of boxes)
      :param expected: number of boxes in a complete column; defaults to
          get_choice_row_count()
      :raises ValueError: if no column has exactly ``expected`` boxes (the
          former bare ``raise`` had no active exception and carried no message)
      """
      if expected is None:
          expected = get_choice_row_count()
      if not any(len(col) == expected for col in rows):
          raise ValueError('no column contains %s boxes' % expected)
  def get_vertical_projective(img):
      """
      Compute the vertical projection of ``img``, pass it to seg(), and return it.

      For each column x, the projection counts pixels equal to 255 (the first
      channel for multi-channel images). A numpy reduction replaces the
      per-pixel cv2.cv.Get2D(cv2.cv.fromarray(...)) calls; that legacy API
      was removed in OpenCV 3 and rebuilt a matrix header for every pixel.

      :param img: ndarray
      :return: list of int, one count per column
      """
      plane = img if img.ndim == 2 else img[..., 0]
      w = (plane == 255).sum(axis=0).tolist()
      # show_fuck(img, w)
      seg(w, img)
      return w
  def get_h_projective(img):
      """
      Compute and display the horizontal projection of ``img``.

      For each row y, the projection counts pixels equal to 255 (the first
      channel for multi-channel images). It is drawn as h[y] black pixels
      from the left edge on a white canvas and shown in a 'painty' window
      until a key is pressed. Numpy replaces the legacy per-pixel
      cv2.cv.Get2D/Set2D calls, which were removed in OpenCV 3.

      :param img: ndarray
      :return: list of int, one count per row (previously None)
      """
      plane = img if img.ndim == 2 else img[..., 0]
      h = (plane == 255).sum(axis=1).tolist()
      painty = np.zeros(img.shape, np.uint8) + 255
      for y in range(img.shape[0]):
          painty[y, :h[y]] = 0
      cv2.imshow('painty', painty)
      cv2.waitKey(0)
      return h
  def show_fuck(img, w):
      """
      Display a vertical projection profile (debug helper).

      On a white canvas shaped like ``img``, the first w[x] pixels of each
      column x are set to black. The result is shown in a 'paintx' window
      until a key is pressed. Numpy slicing replaces the legacy per-pixel
      cv2.cv.Set2D calls, which were removed in OpenCV 3.

      :param img: ndarray, only its shape is used
      :param w: sequence of non-negative ints with at least img.shape[1] entries
      """
      paintx = np.zeros(img.shape, np.uint8) + 255
      for x in range(img.shape[1]):
          # Blacken (value 0) the top w[x] pixels of column x.
          paintx[:w[x], x] = 0
      # Show the image.
      cv2.imshow('paintx', paintx)
      cv2.waitKey(0)
  def seg(w, img, expected_count=18):
      """
      Zero weak entries of a projection profile and count the remaining segments.

      Entries not above mean + 2 * std are set to 0. The counted segments are
      zero-to-non-zero transitions, minus one when the profile ends on a
      non-zero value. The count is printed. If it differs from
      ``expected_count``, the raw and thresholded profiles are displayed
      with show_fuck.

      :param w: sequence of ints, the projection profile
      :param img: ndarray, forwarded to show_fuck for the debug display
      :param expected_count: segment count considered correct (default 18,
          the value previously hard-coded)
      :return: ndarray, the thresholded profile
      """
      # NOTE(review): mean/std are taken over the Counter's multiplicities (how
      # often each profile value occurs), not over the profile values
      # themselves - confirm this is intended.
      counts = np.array(list(Counter(w).values()))  # list(): dict views break np.array on Py3
      threshold = np.mean(counts) + np.std(counts) * 2
      _w = np.array(w)
      _w[_w <= threshold] = 0
      # A segment starts wherever a zero entry is followed by a non-zero one.
      count = int(np.count_nonzero((_w[1:] > 0) & (_w[:-1] == 0)))
      if _w[-1] > 0:
          count -= 1
      print(count)
      if count != expected_count:
          show_fuck(img, w)
          show_fuck(img, _w)
          print('1')
      return _w