Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
sirfoga committed Apr 12, 2021
1 parent 3333581 commit b1d9a15
Show file tree
Hide file tree
Showing 6 changed files with 44 additions and 72 deletions.
3 changes: 2 additions & 1 deletion experiments/human36m/train/human36m_alg.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ opt:
mse_smooth_threshold: 400

n_objects_per_epoch: 15000
n_epochs: 50
n_epochs: 200
n_epochs_long: 1000

batch_size: 8
val_batch_size: 16
Expand Down
28 changes: 16 additions & 12 deletions mvn/datasets/human36m.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,14 @@ def __init__(self,
for camera_idx in self.ignore_cameras
)

train_subjects = ['S1', 'S6', 'S7', 'S8'] # todo solve missing images in 'S5'
test_subjects = ['S9', 'S11']

train_subjects = list(self.labels['subject_names'].index(x) for x in train_subjects)
test_subjects = list(self.labels['subject_names'].index(x) for x in test_subjects)
train_subjects = [
self.labels['subject_names'].index(x)
for x in ['S1', 'S6', 'S7', 'S8'] # todo solve missing images in 'S5'
]
test_subjects = [
self.labels['subject_names'].index(x)
for x in ['S9', 'S11']
]

indices = []

Expand Down Expand Up @@ -109,6 +112,7 @@ def __init__(self,
indices.append(np.nonzero(mask)[0][::retain_every_n_frames_in_test])

self.labels['table'] = self.labels['table'][np.concatenate(indices)]
self.indices = indices

self.num_keypoints = 16 if kind == "mpii" else 17
assert self.labels['table']['keypoints'].shape[1] == 17, "Use a newer 'labels' file"
Expand All @@ -128,7 +132,7 @@ def __init__(self,
def __len__(self):
return len(self.labels['table'])

def __getitem__(self, idx, meshgrids=None):
def __getitem__(self, idx):
sample = defaultdict(list) # return value
shot = self.labels['table'][idx]

Expand Down Expand Up @@ -157,7 +161,6 @@ def __getitem__(self, idx, meshgrids=None):
image_path = os.path.join(
self.h36m_root, subject, action, 'imageSequence' + '-undistorted' * self.undistort_images,
camera_name, 'img_%06d.jpg' % (frame_idx + 1))
print('getting', image_path)

if not os.path.isfile(image_path):
print('%s doesn\'t exist' % image_path) # find them!
Expand Down Expand Up @@ -191,7 +194,6 @@ def __getitem__(self, idx, meshgrids=None):
sample['proj_matrices'].append(retval_camera.projection)

if self.meshgrids:
print('getting undistorted image...')
available_cameras = list(range(len(self.labels['action_names'])))
for camera_idx, bbox in enumerate(shot['bbox_by_camera_tlbr']):
if bbox[2] == bbox[0]: # bbox is empty, which means that this camera is missing
Expand Down Expand Up @@ -237,7 +239,6 @@ def make_meshgrids(self):
bboxes = self.labels['table']['bbox_by_camera_tlbr'][sample_idx]

if (bboxes[:, 2] - bboxes[:, 0]).min() > 0: # if == 0, then some camera is missing
print('getting sample')
sample = self.__getitem__(sample_idx)
assert len(sample['images']) == n_cameras

Expand Down Expand Up @@ -269,7 +270,6 @@ def make_meshgrids(self):
# cache (save) distortion maps
meshgrids[subject_idx, camera_idx] = cv2.convertMaps(meshgrid.reshape((h, w, 2)), None, cv2.CV_16SC2)

print('done meshgrids')
return meshgrids

def evaluate_using_per_pose_error(self, per_pose_error, split_by_subject):
Expand Down Expand Up @@ -318,8 +318,12 @@ def evaluate_by_actions(self, per_pose_error, mask=None):

return subject_scores

def evaluate(self, keypoints_3d_predicted, split_by_subject=False, transfer_cmu_to_human36m=False, transfer_human36m_to_human36m=False):
keypoints_gt = self.labels['table']['keypoints'][:, :self.num_keypoints]
def evaluate(self, keypoints_3d_predicted, indices_predicted=None, split_by_subject=False, transfer_cmu_to_human36m=False, transfer_human36m_to_human36m=False):
keypoints_gt = self.labels['table']['keypoints'][:,:self.num_keypoints]

if indices_predicted:
keypoints_gt = keypoints_gt[indices_predicted]

if keypoints_3d_predicted.shape != keypoints_gt.shape:
raise ValueError(
'`keypoints_3d_predicted` shape should be %s, got %s' % \
Expand Down
1 change: 0 additions & 1 deletion mvn/datasets/human36m_preprocessing/undistort-h36m.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
Usage: `python3 undistort-h36m.py <path/to/Human3.6M-root> <path/to/human36m-multiview-labels.npy> <num-processes>`
"""
import torch
import numpy as np
import cv2
from tqdm import tqdm
Expand Down
2 changes: 0 additions & 2 deletions mvn/models/triangulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,9 +198,7 @@ def forward(self, images, proj_matricies, batch, minimon=None):
except RuntimeError as e:
print("Error: ", e)

print("confidences =", confidences_batch_pred)
print("proj_matricies = ", proj_matricies)
print("keypoints_2d_batch_pred =", keypoints_2d_batch_pred)
exit()

minimon.leave('triangulate')
Expand Down
69 changes: 20 additions & 49 deletions tools/scratch.ipynb

Large diffs are not rendered by default.

13 changes: 6 additions & 7 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@ def setup_human36m_dataloaders(config, is_train, distributed_train):
ignore_cameras=config.dataset.train.ignore_cameras if hasattr(config.dataset.train, "ignore_cameras") else [],
crop=config.dataset.train.crop if hasattr(config.dataset.train, "crop") else True,
)
train_dataset.meshgrids = train_dataset.make_meshgrids()
print(" training dataset length:", len(train_dataset))

train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) if distributed_train else None
Expand Down Expand Up @@ -106,7 +105,6 @@ def setup_human36m_dataloaders(config, is_train, distributed_train):
ignore_cameras=config.dataset.val.ignore_cameras if hasattr(config.dataset.val, "ignore_cameras") else [],
crop=config.dataset.val.crop if hasattr(config.dataset.val, "crop") else True,
)
val_dataset.meshgrids = val_dataset.make_meshgrids()
print(" validation dataset length:", len(val_dataset))

val_dataloader = DataLoader(
Expand Down Expand Up @@ -256,19 +254,19 @@ def one_epoch(model, criterion, opt, config, dataloader, device, epoch, n_iters_
) # ~ 17, 2

if False: # todo debug only
current_view = images_batch[0, 0, 0].detach().cpu().numpy() # grayscale only
current_view = images_batch[batch_i, view_i, 0].detach().cpu().numpy() # grayscale only
canvas = normalize_transformation((0, 255))(current_view)
canvas = cv2.cvtColor(canvas, cv2.COLOR_GRAY2RGB)

# draw circles where GT keypoints are
for pt in keypoints_2d_gt_proj.detach().cpu():
for pt in keypoints_2d_gt_proj.detach().cpu().numpy():
cv2.circle(
canvas, tuple(pt.astype(int)),
2, color=(0, 255, 0), thickness=3
) # green

# draw circles where predicted keypoints are
for pt in keypoints_2d_true_pred.detach().cpu():
for pt in keypoints_2d_true_pred.detach().cpu().numpy():
cv2.circle(
canvas, tuple(pt.astype(int)),
2, color=(0, 0, 255), thickness=3
Expand Down Expand Up @@ -341,8 +339,9 @@ def one_epoch(model, criterion, opt, config, dataloader, device, epoch, n_iters_

try:
scalar_metric, full_metric = dataloader.dataset.evaluate(
results['keypoints_3d']
) # 3D MPJPE (relative to pelvis), all MPJPEs
results['keypoints_3d'],
indices_predicted=results['indexes']
) # average 3D MPJPE (relative to pelvis), all MPJPEs
except Exception as e:
print("Failed to evaluate: ", e)
scalar_metric, full_metric = 0.0, {}
Expand Down

0 comments on commit b1d9a15

Please sign in to comment.