import os
import cv2
import numpy as np
from utils import *
import torch


# Directory of raw infrared images the calibration set is drawn from.
src_root = '/data/horizon_x3/codes/SeAFusion/before_calibration_img/infrared'
# Directory the converted calibration files are written to.
dst_root = '/data/horizon_x3/codes/SeAFusion/calibration_data/infrared'

# Desired number of calibration images.
cal_img_num = 49


## 1. Take the first `cal_img_num` files (in sorted name order) from the
##    source folder as calibration data.
# Only regular files are considered: cv2.imread cannot read a directory and
# would return None for it later on.
# Slicing takes exactly cal_img_num names (a `>` counter check would take one
# extra).
img_names = [
    name
    for name in sorted(os.listdir(src_root))
    if os.path.isfile(os.path.join(src_root, name))
][:cal_img_num]

# Create the destination folder (including any missing parent folders) if it
# does not exist yet. Unlike shelling out to `mkdir`, this is safe for paths
# containing spaces/shell metacharacters and raises on real failures.
os.makedirs(dst_root, exist_ok=True)


# img_path = os.path.join(src_root, img_names[0])
# img = cv2.imread(img_path)
# print(img.shape)


# 2.2 Convert each selected image and dump it as raw uint8 bytes.
for each_imgname in img_names:
    img_path = os.path.join(src_root, each_imgname)

    # Single-channel read: a 2-D (H, W) uint8 array, or None when the file is
    # missing or cannot be decoded as an image.
    img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        raise ValueError('failed to read image: {0}'.format(img_path))

    # Swap the two spatial axes: (H, W) -> (W, H).
    # NOTE(review): this looks adapted from an HWC->CHW transpose used for
    # colour input; on a 2-D grayscale image it transposes the picture itself.
    # Confirm the model expects (1, W, H); if it expects (1, H, W), drop this
    # line (it is a no-op only for square images).
    img = img.transpose((1, 0))
    # Prepend a channel axis: (W, H) -> (1, W, H).
    img = np.expand_dims(img, axis=0)

    # Save to the destination folder as a headerless raw byte file.
    dst_path = os.path.join(dst_root, each_imgname + '.infrared')
    print("write:{0}, shape: {1}".format(dst_path, img.shape))
    # ndarray.tofile always writes in C order, independent of memory layout,
    # so the transposed (non-contiguous) view is serialized correctly.
    img.astype(np.uint8).tofile(dst_path)

print('finish')