-
Notifications
You must be signed in to change notification settings - Fork 0
/
my_albumentations.py
65 lines (47 loc) · 1.9 KB
/
my_albumentations.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
# -*- coding: utf-8 -*-
#
# Developed by Alex Jercan <[email protected]>
#
# References:
#
import albumentations as A
from albumentations.pytorch import ToTensorV2
class MyHorizontalFlip(A.HorizontalFlip):
    """HorizontalFlip that also handles ``depth`` and ``normal`` targets.

    ``depth`` is flipped exactly like the image (``self.apply``). ``normal`` is
    flipped spatially and additionally has its x channel (index 0) negated,
    because mirroring the image left/right inverts the x component of the
    surface normals.
    """

    @property
    def targets(self):
        # Extend the parent's target -> handler mapping with the two extra targets.
        return dict(super().targets, **{'depth': self.apply, 'normal': self.apply_normal})

    def apply_normal(self, img, **params):
        """Flip a normal map horizontally and negate its x channel.

        Args:
            img: normal map indexed as ``img[row, col, channel]``; channel 0 is
                treated as the x component. NOTE(review): assumed to be a
                signed/float array (e.g. normals in [-1, 1]) — negating an
                unsigned, e.g. uint8-encoded, map would not give a valid
                result; confirm against the data loader.
            **params: forwarded unchanged to ``HorizontalFlip.apply``.

        Returns:
            The flipped normal map. The caller's ``img`` is not modified.
        """
        # Negate on a copy: writing into ``img`` directly would mutate the
        # caller's array (or whatever array an earlier transform returned a
        # view of), so the sign change would leak into the source data.
        img = img.copy()
        img[:, :, 0] = -1 * img[:, :, 0]
        return super().apply(img, **params)
class MyVerticalFlip(A.VerticalFlip):
    """VerticalFlip that also handles ``depth`` and ``normal`` targets.

    ``depth`` is flipped exactly like the image (``self.apply``). ``normal`` is
    flipped spatially and additionally has its y channel (index 1) negated,
    because mirroring the image top/bottom inverts the y component of the
    surface normals.
    """

    @property
    def targets(self):
        # Extend the parent's target -> handler mapping with the two extra targets.
        return dict(super().targets, **{'depth': self.apply, 'normal': self.apply_normal})

    def apply_normal(self, img, **params):
        """Flip a normal map vertically and negate its y channel.

        Args:
            img: normal map indexed as ``img[row, col, channel]``; channel 1 is
                treated as the y component. NOTE(review): assumed to be a
                signed/float array (e.g. normals in [-1, 1]) — negating an
                unsigned, e.g. uint8-encoded, map would not give a valid
                result; confirm against the data loader.
            **params: forwarded unchanged to ``VerticalFlip.apply``.

        Returns:
            The flipped normal map. The caller's ``img`` is not modified.
        """
        # Negate on a copy so the caller's array (or a view produced by an
        # earlier transform) is not mutated as a side effect.
        img = img.copy()
        img[:, :, 1] = -1 * img[:, :, 1]  # y axis flip for normal maps
        return super().apply(img, **params)
class MyRandomResizedCrop(A.RandomResizedCrop):
    """RandomResizedCrop that also routes ``depth`` and ``normal`` targets.

    Both extra targets are handled by the parent's ``apply_to_mask``, i.e. they
    are transformed the same way this transform treats a mask.
    """

    @property
    def targets(self):
        # Copy the parent's mapping, then register the extra targets.
        handlers = {**super().targets}
        handlers['depth'] = self.apply_to_mask
        handlers['normal'] = self.apply_to_mask
        return handlers
class MyOpticalDistortion(A.OpticalDistortion):
    """OpticalDistortion that also routes ``depth`` and ``normal`` targets.

    Both extra targets are handled by the parent's ``apply_to_mask``.
    NOTE(review): the normal vectors' components are not re-oriented to match
    the warp — confirm this is acceptable for the downstream loss.
    """

    @property
    def targets(self):
        # Copy the parent's mapping, then register the extra targets.
        handlers = {**super().targets}
        handlers['depth'] = self.apply_to_mask
        handlers['normal'] = self.apply_to_mask
        return handlers
class MyGridDistortion(A.GridDistortion):
    """GridDistortion that also routes ``depth`` and ``normal`` targets.

    Both extra targets are handled by the parent's ``apply_to_mask``.
    NOTE(review): the normal vectors' components are not re-oriented to match
    the warp — confirm this is acceptable for the downstream loss.
    """

    @property
    def targets(self):
        # Copy the parent's mapping, then register the extra targets.
        handlers = {**super().targets}
        handlers['depth'] = self.apply_to_mask
        handlers['normal'] = self.apply_to_mask
        return handlers
class MyLongestMaxSize(A.LongestMaxSize):
    """LongestMaxSize that also routes ``depth`` and ``normal`` targets.

    Both extra targets are handled by the parent's ``apply_to_mask``, so they
    are resized the same way this transform resizes a mask.
    """

    @property
    def targets(self):
        # Copy the parent's mapping, then register the extra targets.
        handlers = {**super().targets}
        handlers['depth'] = self.apply_to_mask
        handlers['normal'] = self.apply_to_mask
        return handlers
class MyPadIfNeeded(A.PadIfNeeded):
    """PadIfNeeded that also routes ``depth`` and ``normal`` targets.

    Both extra targets are handled by the parent's ``apply_to_mask``, so they
    are padded the same way this transform pads a mask.
    """

    @property
    def targets(self):
        # Copy the parent's mapping, then register the extra targets.
        handlers = {**super().targets}
        handlers['depth'] = self.apply_to_mask
        handlers['normal'] = self.apply_to_mask
        return handlers
class MyToTensorV2(ToTensorV2):
    """ToTensorV2 that also converts ``depth`` and ``normal`` targets.

    Both extra targets go through the same ``apply`` used for the image.
    """

    @property
    def targets(self):
        # Copy the parent's mapping, then register the extra targets.
        handlers = {**super().targets}
        handlers['depth'] = self.apply
        handlers['normal'] = self.apply
        return handlers