
Commit

undistort images on the fly 2
sirfoga committed Apr 12, 2021
1 parent 68ff015 commit 3333581
Showing 4 changed files with 17 additions and 14 deletions.
2 changes: 1 addition & 1 deletion experiments/human36m/train/human36m_alg.yaml
@@ -74,4 +74,4 @@ dataset:
max_n_views: 31
num_workers: 8

- retain_every_n_frames_in_test: 1000
+ retain_every_n_frames_in_test: 500
21 changes: 12 additions & 9 deletions mvn/datasets/human36m.py
@@ -123,12 +123,12 @@ def __init__(self,
f"[train={train}, test={test}] {labels_path} has {len(self)} samples, but '{pred_results_path}' " + \
f"has {len(self.keypoints_3d_pred)}. Did you follow all preprocessing instructions carefully?"

- self.meshgrids = None  # to undistort (optional), call 'make_undistort' before!
+ self.meshgrids = None

def __len__(self):
return len(self.labels['table'])

- def __getitem__(self, idx, force_undistort=True):
+ def __getitem__(self, idx, meshgrids=None):
sample = defaultdict(list) # return value
shot = self.labels['table'][idx]

@@ -157,6 +157,7 @@ def __getitem__(self, idx, force_undistort=True):
image_path = os.path.join(
self.h36m_root, subject, action, 'imageSequence' + '-undistorted' * self.undistort_images,
camera_name, 'img_%06d.jpg' % (frame_idx + 1))
+ print('getting', image_path)

if not os.path.isfile(image_path):
print('%s doesn\'t exist' % image_path) # find them!
@@ -189,11 +190,11 @@ def __getitem__(self, idx, force_undistort=True):
sample['cameras'].append(retval_camera)
sample['proj_matrices'].append(retval_camera.projection)

- if force_undistort and self.meshgrids:
+ if self.meshgrids:
print('getting undistorted image...')
available_cameras = list(range(len(self.labels['action_names'])))
for camera_idx, bbox in enumerate(shot['bbox_by_camera_tlbr']):
- if bbox[2] == bbox[0]: # bbox is empty, which means that this camera is missing
+ if bbox[2] == bbox[0]:  # bbox is empty, which means that this camera is missing
available_cameras.remove(camera_idx)

for i, (camera_idx, image) in enumerate(zip(available_cameras, sample['images'])):
@@ -222,9 +223,9 @@ def __getitem__(self, idx, force_undistort=True):
sample.default_factory = None
return sample

- def make_undistort(self):
-     print("... computing distorted meshgrids")
+ def make_meshgrids(self):
+     print(" computing distorted meshgrids")

n_subjects = len(self.labels['subject_names'])
n_cameras = len(self.labels['camera_names'])
meshgrids = np.empty((n_subjects, n_cameras), dtype=object)
@@ -236,7 +237,8 @@ def make_undistort(self):
bboxes = self.labels['table']['bbox_by_camera_tlbr'][sample_idx]

if (bboxes[:, 2] - bboxes[:, 0]).min() > 0: # if == 0, then some camera is missing
- sample = self.__getitem__(sample_idx, force_undistort=False)
+ print('getting sample')
+ sample = self.__getitem__(sample_idx)
assert len(sample['images']) == n_cameras

for camera_idx, (camera, image) in enumerate(zip(sample['cameras'], sample['images'])):
@@ -267,7 +269,8 @@ def make_undistort(self):
# cache (save) distortion maps
meshgrids[subject_idx, camera_idx] = cv2.convertMaps(meshgrid.reshape((h, w, 2)), None, cv2.CV_16SC2)

- self.meshgrids = meshgrids
+ print('done meshgrids')
+ return meshgrids

- def evaluate_using_per_pose_error(self, per_pose_error, split_by_subject):
+ def evaluate_by_actions(self, per_pose_error, mask=None):
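For context, the pattern the dataset now implements is: build the undistortion maps once per (subject, camera) pair in make_meshgrids, cache them in OpenCV's compact fixed-point format via cv2.convertMaps, and let __getitem__ apply them with a single remap per loaded view. The sketch below is a minimal, self-contained version of that pattern with made-up intrinsics and distortion coefficients; the real maps come from the Human3.6M calibration, and the commit builds its own float meshgrid before the cv2.convertMaps call visible above (cv2.initUndistortRectifyMap is used here only to keep the example short).

```python
# Minimal sketch of the build-once / remap-per-image pattern this commit moves
# into the dataset. Camera values below are made up; the real maps are built
# from the calibration of each (subject, camera) pair.
import cv2
import numpy as np

h, w = 1000, 1000
K = np.array([[1145.0,    0.0, 512.0],
              [   0.0, 1143.0, 515.0],
              [   0.0,    0.0,   1.0]])
dist = np.array([-0.2, 0.24, -0.001, -0.0008, 0.0])  # k1, k2, p1, p2, k3

# 1) build the undistortion maps once ...
map1, map2 = cv2.initUndistortRectifyMap(K, dist, None, K, (w, h), cv2.CV_32FC1)
# 2) ... and convert them to the compact fixed-point format that gets cached
#    (the commit stores one such pair per (subject, camera) in an object array)
fixed_maps = cv2.convertMaps(map1, map2, cv2.CV_16SC2)

meshgrids = np.empty((1, 1), dtype=object)   # shape: (n_subjects, n_cameras)
meshgrids[0, 0] = fixed_maps

# 3) at load time, undistorting a frame is a single remap call
image = np.zeros((h, w, 3), dtype=np.uint8)  # stand-in for a frame read from disk
m1, m2 = meshgrids[0, 0]
undistorted = cv2.remap(image, m1, m2, cv2.INTER_LINEAR)
print(undistorted.shape)
```

Keeping only step 3 inside __getitem__ leaves just the cheap remap on the per-sample path, which is what "undistort images on the fly" refers to.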
2 changes: 1 addition & 1 deletion tools/run.sbatch
@@ -20,7 +20,7 @@ conda activate ${KERNELS_DIR}/${KERNEL_NAME} # or source
which python # just to check

CODE_FOLDER="/home/stfo194b/tesi/learnable-triangulation-pytorch"
LOGS_FOLDER="/projects/p_humanpose/learnable-triangulation/logs"
LOGS_FOLDER="/scratch/ws/0/stfo194b-p_humanpose/learnable-triangulation-pytorch/logs"
EXP_CONFIG="experiments/human36m/train/human36m_alg.yaml"

cd ${CODE_FOLDER}
6 changes: 3 additions & 3 deletions train.py
@@ -69,7 +69,7 @@ def setup_human36m_dataloaders(config, is_train, distributed_train):
ignore_cameras=config.dataset.train.ignore_cameras if hasattr(config.dataset.train, "ignore_cameras") else [],
crop=config.dataset.train.crop if hasattr(config.dataset.train, "crop") else True,
)
- train_dataset.make_undistort()
+ train_dataset.meshgrids = train_dataset.make_meshgrids()
print(" training dataset length:", len(train_dataset))

train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) if distributed_train else None
@@ -106,8 +106,8 @@ def setup_human36m_dataloaders(config, is_train, distributed_train):
ignore_cameras=config.dataset.val.ignore_cameras if hasattr(config.dataset.val, "ignore_cameras") else [],
crop=config.dataset.val.crop if hasattr(config.dataset.val, "crop") else True,
)
- val_dataset.make_undistort()
- print(" validation dataset length:", len(train_dataset))
+ val_dataset.meshgrids = val_dataset.make_meshgrids()
+ print(" validation dataset length:", len(val_dataset))

val_dataloader = DataLoader(
val_dataset,
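For reference, the dataloader setup in train.py now follows a simple ordering: construct the dataset, attach the precomputed meshgrids, then wrap it in a DataLoader. Below is a runnable toy version of that ordering; DummyMultiViewDataset is a stand-in for the repo's Human3.6M dataset class, whose real constructor arguments are elided in the hunks above.

```python
# Toy sketch of the compute-once-then-wrap ordering used in setup_human36m_dataloaders.
# DummyMultiViewDataset is a stand-in, not the repo's dataset class.
import numpy as np
from torch.utils.data import Dataset, DataLoader

class DummyMultiViewDataset(Dataset):
    def __init__(self, n_samples=8):
        self.n_samples = n_samples
        self.meshgrids = None  # attached by the caller, as in train.py above

    def make_meshgrids(self):
        # stand-in for the real per-(subject, camera) map computation
        return np.empty((1, 4), dtype=object)

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        # a real sample would be undistorted here when meshgrids are present
        return {'idx': idx, 'undistorted': self.meshgrids is not None}

dataset = DummyMultiViewDataset()
dataset.meshgrids = dataset.make_meshgrids()  # precompute once, before any batches are drawn
loader = DataLoader(dataset, batch_size=4)

for batch in loader:
    print(batch['idx'], batch['undistorted'])
```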
