-
Notifications
You must be signed in to change notification settings - Fork 0
/
processMiniBatch.m
22 lines (19 loc) · 1.03 KB
/
processMiniBatch.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
function [imagesBatch, inputIDs, attentionMask, segmentIDs] = processMiniBatch(images, tokenisedCaption, ~)
%PROCESSMINIBATCH Collate images and tokenised captions into batch arrays.
%
%   [imagesBatch, inputIDs, attentionMask, segmentIDs] = ...
%       processMiniBatch(images, tokenisedCaption, ~)
%
%   Inputs
%     images           - cell array of images; each one is resized to
%                        227-by-227 before stacking.
%     tokenisedCaption - cell array of token-ID sequences, padded along
%                        dimension 2 (presumably 1-by-T row vectors —
%                        TODO confirm against the caller).
%     ~                - third input is accepted and ignored.
%
%   Outputs
%     imagesBatch   - resized images concatenated along dimension 4.
%     inputIDs      - padded token IDs with dimensions 2 and 3 swapped.
%     attentionMask - padding mask from PADSEQUENCES, permuted the same way.
%     segmentIDs    - array of ones with the same size as inputIDs.
%
%   All four outputs are moved to the GPU when canUseGPU reports one.

    % Padding code is hard coded for now instead of being read from the
    % tokenizer (i.e. [~, tokenizer] = bert("Model","tiny") and then
    % tokenizer.PaddingCode).
    padCode = 1;
    targetSize = [227 227];
    swapLastTwoDims = [1 3 2];

    % --- Text branch -----------------------------------------------------
    % PADSEQUENCES pads along dimension 2 and stacks the sequences along a
    % new trailing dimension ("CTB" per the original note); the permute
    % then swaps dimensions 2 and 3.
    [paddedIDs, paddedMask] = padsequences(tokenisedCaption, 2, "PaddingValue", padCode);
    inputIDs = permute(paddedIDs, swapLastTwoDims);
    attentionMask = permute(paddedMask, swapLastTwoDims);

    % Segment IDs are always 1: a constraint imposed by the `bert`
    % language model.
    segmentIDs = ones(size(inputIDs));

    % --- Image branch ----------------------------------------------------
    % TODO: Move image resizing outside into a transform datastore.
    resizedImages = cell(size(images));
    for idx = 1:numel(images)
        resizedImages{idx} = imresize(images{idx}, targetSize);
    end
    imagesBatch = cat(4, resizedImages{:});

    % --- Device placement ------------------------------------------------
    if ~canUseGPU
        return
    end
    imagesBatch = gpuArray(imagesBatch);
    inputIDs = gpuArray(inputIDs);
    attentionMask = gpuArray(attentionMask);
    segmentIDs = gpuArray(segmentIDs);
end