Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
alexjercan committed May 10, 2021
1 parent 55be05e commit 7701a14
Show file tree
Hide file tree
Showing 7 changed files with 90 additions and 43 deletions.
17 changes: 11 additions & 6 deletions dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from copy import copy

from torch.utils.data import Dataset, DataLoader
from util import load_image, load_normal
from util import load_depth, load_image, load_normal


def create_dataloader(dataset_root, json_path, batch_size=2, transform=None, workers=8, pin_memory=True, shuffle=True):
Expand Down Expand Up @@ -42,21 +42,24 @@ def __getitem__(self, index):
def __load__(self, index):
img_path = os.path.join(self.dataset_root, self.json_data[index]["image"])
normal_path = os.path.join(self.dataset_root, self.json_data[index]["normal"])
depth_path = os.path.join(self.dataset_root, self.json_data[index]["depth"])

img = load_image(img_path)
normal = load_normal(normal_path)
depth = load_depth(depth_path)

return img, normal
return img, normal, depth

def __transform__(self, data):
img, normal = data
img, normal, depth = data

if self.transform is not None:
augmentations = self.transform(image=img, normal=normal)
augmentations = self.transform(image=img, normal=normal, depth=depth)
img = augmentations["image"]
normal = augmentations["normal"]
depth = augmentations["depth"]

return img, normal
return img, normal, depth


class LoadImages():
Expand Down Expand Up @@ -139,6 +142,7 @@ def visualize(image):
],
additional_targets={
'normal': 'normal',
'depth': 'depth',
}
)

Expand All @@ -150,9 +154,10 @@ def visualize(image):
)

_, dataloader = create_dataloader("../bdataset", "test.json", transform=my_transform)
imgs, normals = next(iter(dataloader))
imgs, normals, depths = next(iter(dataloader))
assert imgs.shape == (2, 3, 256, 256), f"dataset error {imgs.shape}"
assert normals.shape == (2, 3, 256, 256), f"dataset error {normals.shape}"
assert depths.shape == (2, 1, 256, 256), f"dataset error {depths.shape}"

dataset = LoadImages(JSON, transform=img_transform)
og_img, img, path = next(iter(dataset))
Expand Down
11 changes: 6 additions & 5 deletions metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,21 @@ def __init__(self, batch_size) -> None:
self.error_sum = {}
self.error_avg = {}

def evaluate(self, predictions, normals):
error_val = evaluate_error_classification(predictions, normals)

def evaluate(self, predictions, data):
(normals, depths) = data
error_val = evaluate_error_classification(predictions, None)

self.total_size += self.batch_size
self.error_avg = avg_error(self.error_sum, error_val, self.total_size, self.batch_size)
return self.error_avg

def show(self):
    """Return the textual summary of this metric object.

    Currently a stub: it always returns an empty string and reads no
    instance state.
    """
    return ""


def evaluate_error_classification(predictions, targets):
    """Return the per-metric errors for one batch as a dict.

    Placeholder implementation: no metric is computed yet, so both
    arguments are ignored and a new, empty dict is returned on every call.
    """
    return {}


Expand Down
65 changes: 44 additions & 21 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,20 +128,13 @@ def forward(self, x):
x = self.predict(*x)
return x


class LossFunction(nn.Module):
class ContinuityLoss(nn.Module):
def __init__(self):
super(LossFunction, self).__init__()
self.loss = nn.CrossEntropyLoss()
super(ContinuityLoss, self).__init__()
self.y_loss = nn.L1Loss()
self.z_loss = nn.L1Loss()
self.obj_loss = nn.L1Loss()

self.loss_val = 0
self.c_loss_val= 0
self.obj_loss_val = 0

def forward(self, predictions, normals):
def forward(self, predictions):
device = predictions.device

hp_y = predictions[:, 1:, :, :] - predictions[:, 0:-1, :, :]
Expand All @@ -150,24 +143,54 @@ def forward(self, predictions, normals):
hp_y_target = torch.zeros_like(hp_y, device=device)
hp_z_target = torch.zeros_like(hp_z, device=device)

_, target = torch.max(predictions, 1)
return (self.y_loss(hp_y, hp_y_target) + self.z_loss(hp_z, hp_z_target))


class SurfaceLoss(nn.Module):
def __init__(self, eps=1e-8):
super(SurfaceLoss, self).__init__()
self.loss = nn.L1Loss()
self.eps = eps

def forward(self, predictions, normals):
_, predictions = torch.max(predictions, 1, keepdim=True)

pred_faces = (target > 0.1).float()
normal_faces = (torch.abs(normals[:, 2, :, :]) < 1e-3).float()
surfaces = (torch.abs(normals) >= self.eps)
surfaces = torch.logical_or(surfaces[:, 0:1, :, :], torch.logical_or(surfaces[:, 1:2, :, :], surfaces[:, 2:3, :, :]))
surfaces = surfaces.float()

loss = self.loss(predictions, target) * 1.0
c_loss = (self.y_loss(hp_y, hp_y_target) + self.z_loss(hp_z, hp_z_target)) * 5.0
o_loss = self.obj_loss(pred_faces, normal_faces) * 1.0
return self.loss(predictions, predictions * surfaces)


class LossFunction(nn.Module):
def __init__(self):
super(LossFunction, self).__init__()
self.c_loss = ContinuityLoss()
self.s_loss = SurfaceLoss()
self.f_loss = nn.CrossEntropyLoss()

self.c_loss_val= 0
self.s_loss_val = 0
self.f_loss_val = 0

def forward(self, predictions, data):
(normals, depths) = data

c_loss = self.c_loss(predictions) * 5.0
s_loss = self.s_loss(predictions, normals) * 1.0

_, target = torch.max(predictions, 1)
f_loss = self.f_loss(predictions, target) * 1.0

self.loss_val = loss.item()
self.c_loss_val = c_loss.item()
self.obj_loss_val = o_loss.item()
self.s_loss_val = s_loss.item()
self.f_loss_val = f_loss.item()

return loss + c_loss + o_loss
return c_loss + s_loss + f_loss

def show(self):
loss = self.loss_val + self.c_loss_val + self.obj_loss_val
return f'(total:{loss:.4f} x_entropy:{self.loss_val:.4f} cont:{self.c_loss_val:.4f} obj:{self.obj_loss_val:.4f})'
loss = self.c_loss_val + self.s_loss_val + self.f_loss_val
return f'(total:{loss:.4f} c:{self.c_loss_val:.4f} s:{self.s_loss_val:.4f} f:{self.f_loss_val:.4f})'


if __name__ == "__main__":
Expand Down
9 changes: 4 additions & 5 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,12 @@ def run_test(model, dataloader, loss_fn, metric_fn):
loop = tqdm(dataloader, position=0, leave=True)

for _, tensors in enumerate(loop):
imgs, normals = tensors_to_device(tensors, DEVICE)
imgs, normals, depths = tensors_to_device(tensors, DEVICE)
with torch.no_grad():
imgs = imgs.to(DEVICE, non_blocking=True)

predictions = model(imgs)

loss_fn(predictions, normals)
metric_fn.evaluate(predictions, normals)
loss_fn(predictions, (normals, depths))
metric_fn.evaluate(predictions, (normals, depths))
loop.close()


Expand All @@ -46,6 +44,7 @@ def test(model=None, config=None):
],
additional_targets={
'normal': 'normal',
'depth' : 'depth',
}
)

Expand Down
7 changes: 4 additions & 3 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,12 @@ def train_one_epoch(model, dataloader, loss_fn, metric_fn, solver, epoch_idx):
loop = tqdm(dataloader, position=0, leave=True)

for _, tensors in enumerate(loop):
imgs, normals = tensors_to_device(tensors, DEVICE)
imgs, normals, depths = tensors_to_device(tensors, DEVICE)

predictions = model(imgs)

loss = loss_fn(predictions, normals)
metric_fn.evaluate(predictions, normals)
loss = loss_fn(predictions, (normals, depths))
metric_fn.evaluate(predictions, (normals, depths))

model.zero_grad()
loss.backward()
Expand Down Expand Up @@ -72,6 +72,7 @@ def train(config=None, config_test=None):
],
additional_targets={
'normal': 'normal',
'depth' : 'depth',
}
)

Expand Down
6 changes: 3 additions & 3 deletions train.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
DATASET_ROOT: "../bdataset"
JSON_PATH: "train.json"
DATASET_ROOT: "../bdataset_segmentation"
JSON_PATH: "test.json"
BATCH_SIZE: 4
IMAGE_SIZE: 256
WORKERS: 8
Expand All @@ -14,7 +14,7 @@ WEIGHT_DECAY: 0.0001
MILESTONES: [50, 100, 200]
GAMMA: 0.2

NUM_EPOCHS: 300
NUM_EPOCHS: 50
TEST: False
OUT_PATH: "./runs"
LOAD_MODEL: False
Expand Down
18 changes: 18 additions & 0 deletions util.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@ def load_image(path):
return img


def load_depth(path, max_depth=80):
    """Load a depth map from ``path`` via ``exr2depth``.

    Args:
        path: Location of the depth image on disk.
        max_depth: Forwarded to ``exr2depth`` as ``maxvalue``.

    Returns:
        Whatever ``exr2depth`` returns for a readable file (1 channel depth).

    Raises:
        AssertionError: If ``exr2depth`` returns ``None`` for ``path``.
    """
    depth = exr2depth(path, maxvalue=max_depth)  # 1 channel depth
    if depth is None:
        # Raised explicitly (not via ``assert``) so the check is not stripped
        # under ``python -O``; the exception type callers see is unchanged.
        # The f-string also keeps the message usable for non-str paths.
        raise AssertionError(f'Image Not Found {path}')
    return depth


def load_normal(path):
img = exr2normal(path) # 3 channel normal
assert img is not None, 'Image Not Found ' + path
Expand All @@ -36,6 +42,18 @@ def img2rgb(path):
return img


def exr2depth(path, maxvalue=80):
    """Read a depth image and normalise it by ``maxvalue``.

    Values greater than ``maxvalue`` are clamped to ``maxvalue`` and the
    result is divided by ``maxvalue``. Values below zero are not clamped.

    Args:
        path: Location of the depth image on disk.
        maxvalue: Clamp threshold and normalisation divisor.

    Returns:
        A ``float32`` array reshaped to ``(rows, cols, -1)``, or ``None`` when
        ``path`` is not an existing file or OpenCV cannot decode it.
    """
    if not os.path.isfile(path):
        return None

    img = cv2.imread(path, cv2.IMREAD_GRAYSCALE | cv2.IMREAD_ANYDEPTH)
    if img is None:
        # cv2.imread reports a failed decode by returning None instead of
        # raising; propagate it so the caller's "not found" check can fire
        # rather than crashing on the indexing below.
        return None

    img[img > maxvalue] = maxvalue
    img = img / maxvalue

    # The division already produced a fresh array, so no extra copy is needed.
    return img.astype(np.float32).reshape((img.shape[0], img.shape[1], -1))


def exr2normal(path):
if not os.path.isfile(path):
return None
Expand Down

0 comments on commit 7701a14

Please sign in to comment.