import os
import cv2
import numpy as np
# from utils import *
# import torch
from hobot_dnn import pyeasy_dnn as dnn

def get_hw(pro):
    """Extract the spatial (height, width) pair from tensor properties.

    Args:
        pro: Tensor-properties object exposing ``layout`` (a string) and
            ``shape`` (an indexable sequence of dimensions), such as the
            ``properties`` of a ``hobot_dnn`` model input.

    Returns:
        tuple: ``(height, width)``. Axes 2 and 3 are used when ``layout`` is
        exactly ``"NCHW"``; for any other layout axes 1 and 2 are used, i.e.
        an NHWC ordering is assumed.
    """
    first_spatial_axis = 2 if pro.layout == "NCHW" else 1
    dims = pro.shape
    return dims[first_spatial_axis], dims[first_spatial_axis + 1]

def bgr_to_ycrcb(img, delta=0.5):
    """Split a channel-first BGR image into Y, Cr and Cb planes.

    Args:
        img: Array whose first axis holds the B, G and R planes in that order
            (e.g. shape ``(3, H, W)``). Integer or floating-point input is
            accepted; values are used as-is (no rescaling).
        delta: Offset added to both chroma planes. The default 0.5 suits
            images normalised to [0, 1]; for 8-bit data in [0, 255] the
            conventional offset is 128.

    Returns:
        tuple: ``(Y, Cr, Cb)`` in that order, each a float64 array that keeps
        a length-1 leading axis (e.g. ``(1, H, W)``).
    """
    # Work in float64: a uint8 cast would wrap negative/out-of-range values
    # and truncate fractional (e.g. normalised [0, 1]) input to 0. For uint8
    # input the arithmetic below is unchanged, since it promotes to float64.
    img = np.asarray(img, dtype=np.float64)
    # Slices (not plain indices) keep the channel axis, so each plane
    # retains the same rank as the input.
    B = img[0:1]
    G = img[1:2]
    R = img[2:3]

    Y = 0.299 * R + 0.587 * G + 0.114 * B
    Cr = (R - Y) * 0.713 + delta
    Cb = (B - Y) * 0.564 + delta
    return Y, Cr, Cb


def ycrcb_to_bgr(one, delta=0.5):
    """Convert a channel-last YCrCb image back to BGR.

    Inverts the Y/Cr/Cb formulas used by ``bgr_to_ycrcb`` (same coefficients
    and, by default, the same 0.5 chroma offset). Note the layout differs:
    here the channels are on the LAST axis, e.g. shape ``(H, W, 3)``.

    Args:
        one: Array of shape ``(..., 3)`` holding Y, Cr, Cb in that order.
        delta: Chroma offset that was added during the forward conversion
            (0.5 by default; 128 is the usual value for 8-bit data).

    Returns:
        float64 array of shape ``(..., 3)`` with channels B, G, R. Values are
        neither clipped nor rounded.

    Raises:
        ValueError: if the last axis does not have exactly 3 channels.
    """
    # Keep float precision: a uint8 cast would wrap negative chroma values
    # (which a 0.5 offset produces for 8-bit images) and truncate fractions.
    ycrcb = np.asarray(one, dtype=np.float64)
    if ycrcb.ndim < 1 or ycrcb.shape[-1] != 3:
        raise ValueError(
            f"expected channel-last input with 3 channels, got shape {ycrcb.shape}"
        )
    Y, Cr, Cb = ycrcb[..., 0], ycrcb[..., 1], ycrcb[..., 2]

    B = (Cb - delta) / 0.564 + Y
    R = (Cr - delta) / 0.713 + Y
    # Solve Y = 0.299*R + 0.587*G + 0.114*B for G.
    G = (Y - 0.299 * R - 0.114 * B) / 0.587

    return np.stack([B, G, R], axis=-1)


def _load_image(path, flags, size):
    """Read an image with OpenCV and resize it to ``size`` = (width, height).

    Raises:
        FileNotFoundError: if OpenCV cannot read ``path`` (``cv2.imread``
            returns None instead of raising).
    """
    img = cv2.imread(path, flags)
    if img is None:
        raise FileNotFoundError(f"could not read image: {path}")
    return cv2.resize(img, size)


# img_path: full path to each input image
visible_img_path = '/root/SeAFusion/before_calibration_img/visible/00001D.png'

infrared_img_path = '/root/SeAFusion/before_calibration_img/infrared/00001D.png'

# model_path: full path to the quantized model
model_path = '/root/SeAFusion/SeAFusionmodel.bin'

models = dnn.load(model_path)
# Spatial size declared by each model input; the images are resized to these
# so the tensors passed to forward() match what the model expects.
model_h, model_w = get_hw(models[0].inputs[0].properties)
ir_h, ir_w = get_hw(models[0].inputs[1].properties)


img_v = _load_image(visible_img_path, cv2.IMREAD_COLOR, (int(model_w), int(model_h)))
print(img_v.shape)  # H x W x 3 as returned by cv2
# Move channels first: H x W x C -> C x H x W (BGR, CHW). The axes must be
# (2, 0, 1); (2, 1, 0) would yield C x W x H, swapping height and width.
img_v = img_v.transpose((2, 0, 1))
# bgr_to_ycrcb returns the planes in (Y, Cr, Cb) order.
img_y, img_Cr, img_Cb = bgr_to_ycrcb(img_v)
img_y = np.expand_dims(img_y, axis=0)  # 1 x 1 x H x W, same as img_i
print(img_y.shape)

# IMREAD_GRAYSCALE yields a uint8 H x W array, so no transpose or dtype cast
# is needed; just add the batch and channel axes -> 1 x 1 x H x W (NCHW).
img_i = _load_image(infrared_img_path, cv2.IMREAD_GRAYSCALE, (int(ir_w), int(ir_h)))
img_i = img_i[np.newaxis, np.newaxis, :, :]
print(img_i.shape)

# Add the batch axis so the visible input is NCHW like img_i.
# NOTE(review): the raw BGR image is fed here while img_y (the luma plane,
# already shaped like img_i) is computed but unused - confirm which one the
# model's first input expects. The CHW construction above also assumes the
# model inputs are NCHW; check models[0].inputs[i].properties.layout.
img_v = np.expand_dims(img_v, axis=0)
outputs = models[0].forward((img_v, img_i))

output = (outputs[0].buffer,)

