-
Notifications
You must be signed in to change notification settings - Fork 18
/
kamiak_train_msda_upper.srun
95 lines (83 loc) · 4.21 KB
/
kamiak_train_msda_upper.srun
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
#!/bin/bash
#SBATCH --job-name=train
#SBATCH --output=slurm_logs/train_%A_%a.out
#SBATCH --error=slurm_logs/train_%A_%a.err
#SBATCH --cpus-per-task=1
#SBATCH --gres=gpu:1
#SBATCH --partition=cook,free_gpu,cahnrs_gpu,kamiak
#SBATCH --time=1-00:00:00
#SBATCH --mem=20G
#SBATCH --array=0-170
#
# The upper bound is run separately because running it in
# kamiak_train_real.srun would produce many duplicates: for upper bounds,
# neither the combination nor the number of source domains affects the result.
# Thus, just run three of each.
#
# Load the two helper scripts into this shell. remotedir, logFolder and
# modelFolder are used later but never assigned in this file, so presumably
# one of these scripts defines them -- TODO confirm.
source kamiak_config.sh
source kamiak_tensorflow_gpu.sh
# Export OMP_NUM_THREADS with the value Slurm reports for CPUs per task
# (the #SBATCH header requests --cpus-per-task=1).
OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK
export OMP_NUM_THREADS
# Errors
# Print "Exiting" on stdout and terminate the script with exit status 1.
# Installed as the SIGTERM/SIGINT handler by the trap that follows.
handle_terminate() {
    printf '%s\n' "Exiting"
    exit 1
}
# Print a generic failure notice on stdout and terminate the script with exit
# status 1. Callers print their own specific message before invoking this.
handle_error() {
    printf '%s\n' "Error occurred -- exiting"
    exit 1
}
# Run handle_terminate when the shell receives SIGTERM or SIGINT.
trap handle_terminate TERM INT
# The first positional argument is the suffix, i.e. results are stored in
# kamiak-{models,logs}-suffix. It is shifted off so that "$@" afterwards holds
# only the remaining arguments. An empty or missing suffix aborts the job.
suffix=$1
shift
if [[ -z "$suffix" ]]; then
    echo "no suffix specified"
    handle_error
fi
# Only the upper-bound method is listed for this script.
methods=("upper")
# Debug numbers; every adaptation problem is run once per entry.
debugnums=(1 2 3)

# number of adaptation problems = 57
# uids, datasets, sources and targets are parallel arrays: position i of each
# one describes adaptation problem i. Rows below are grouped by dataset.
uids=(
    # ucihar (10)
    u0 u3 u5_1 u5_4 u5_7 u5_10 u5_13 u5_16 u5_19 u5_22
    # ucihhar (9)
    u0 u3 u5_1 u5_4 u5_7 u5_10 u5_13 u5_16 u5_19
    # uwave (8)
    u0 u3 u5_1 u5_4 u5_7 u5_10 u5_13 u5_16
    # watch_noother (10)
    u0 u3 u5_1 u5_4 u5_7 u5_10 u5_13 u5_16 u5_19 u5_22
    # wisdm_ar (10)
    u0 u3 u5_1 u5_4 u5_7 u5_10 u5_13 u5_16 u5_19 u5_22
    # wisdm_at (10)
    u0 u3 u5_1 u5_4 u5_7 u5_10 u5_13 u5_16 u5_19 u5_22
)
datasets=(
    "ucihar" "ucihar" "ucihar" "ucihar" "ucihar"
    "ucihar" "ucihar" "ucihar" "ucihar" "ucihar"
    "ucihhar" "ucihhar" "ucihhar" "ucihhar" "ucihhar"
    "ucihhar" "ucihhar" "ucihhar" "ucihhar"
    "uwave" "uwave" "uwave" "uwave"
    "uwave" "uwave" "uwave" "uwave"
    "watch_noother" "watch_noother" "watch_noother" "watch_noother" "watch_noother"
    "watch_noother" "watch_noother" "watch_noother" "watch_noother" "watch_noother"
    "wisdm_ar" "wisdm_ar" "wisdm_ar" "wisdm_ar" "wisdm_ar"
    "wisdm_ar" "wisdm_ar" "wisdm_ar" "wisdm_ar" "wisdm_ar"
    "wisdm_at" "wisdm_at" "wisdm_at" "wisdm_at" "wisdm_at"
    "wisdm_at" "wisdm_at" "wisdm_at" "wisdm_at" "wisdm_at"
)
# Every source entry is the empty string; the upper-bound handling later in
# the script overwrites the selected source with the target.
sources=(
    # ucihar (10)
    "" "" "" "" "" "" "" "" "" ""
    # ucihhar (9)
    "" "" "" "" "" "" "" "" ""
    # uwave (8)
    "" "" "" "" "" "" "" ""
    # watch_noother (10)
    "" "" "" "" "" "" "" "" "" ""
    # wisdm_ar (10)
    "" "" "" "" "" "" "" "" "" ""
    # wisdm_at (10)
    "" "" "" "" "" "" "" "" "" ""
)
targets=(
    # ucihar (10)
    "21" "4" "1" "24" "9" "8" "25" "5" "29" "18"
    # ucihhar (9)
    "1" "0" "5" "2" "8" "4" "7" "6" "3"
    # uwave (8)
    "2" "1" "6" "3" "8" "7" "5" "4"
    # watch_noother (10)
    "11" "2" "1" "12" "5" "4" "10" "3" "6" "13"
    # wisdm_ar (10)
    "7" "1" "23" "8" "32" "28" "4" "30" "3" "21"
    # wisdm_at (10)
    "40" "7" "1" "47" "17" "15" "14" "8" "6" "34"
)
# Make sure the job array and the problem tables agree before doing any work.
# The array must span 0 .. (methods x debugnums x problems) - 1 so that every
# task id decodes to exactly one combination.
correct_min=0
correct_max=$(( ${#methods[@]} * ${#debugnums[@]} * ${#sources[@]} - 1 ))

# The per-problem arrays must all have the same length.
if [[ "${#sources[@]}" != "${#targets[@]}" ]]; then
    echo "source/target sizes should match"
    handle_error
fi
if [[ "${#sources[@]}" != "${#uids[@]}" ]]; then
    echo "length of sources and uids arrays differ"
    handle_error
fi
if [[ "${#sources[@]}" != "${#datasets[@]}" ]]; then
    echo "length of sources and datasets arrays differ"
    handle_error
fi

# Compared as strings, so an unset Slurm variable fails the check instead of
# being treated as the number 0.
if [[ "$SLURM_ARRAY_TASK_MIN" != "$correct_min" ]]; then
    echo "array min should be $correct_min"
    handle_error
fi
if [[ "$SLURM_ARRAY_TASK_MAX" != "$correct_max" ]]; then
    echo "array max should be $correct_max"
    handle_error
fi
# Decode the flat Slurm array task id into three indices
# (indexing scheme: https://stackoverflow.com/a/34363187):
#   index1 -- adaptation problem, varies fastest
#   index2 -- debug number
#   index3 -- method, varies slowest
index=$SLURM_ARRAY_TASK_ID
index1max=${#sources[@]}
index2max=${#debugnums[@]}

index3=$(( index / (index1max * index2max) ))
# Keep only the part of the id that lies within the selected method.
index=$(( index % (index1max * index2max) ))
index2=$(( index / index1max ))
index1=$(( index % index1max ))

# Look up the settings for this task.
method=${methods[index3]}
debugnum=${debugnums[index2]}
uid=${uids[index1]}
dataset_name=${datasets[index1]}
source=${sources[index1]}
target=${targets[index1]}
# Extra command-line arguments for main.py, filled in per method below.
additional_args=()

# The upper bound is trained as method "none": the target domain is passed as
# the source and the target is left empty.
case "$method" in
    upper)
        method="none"
        source="$target"
        target=""
        ;;
esac

# Training / model options. With methods=("upper") the method is "none" by
# this point, so this branch does not fire for the current method list.
case "$method" in
    random)
        additional_args+=("--log_val_steps=100")
        ;;
esac
# Write the resolved settings for this array task to stdout.
printf '%s #%s\n' "$suffix" "$SLURM_ARRAY_TASK_ID"
printf 'Method: %s\n' "$method"
printf 'DebugNum: %s\n' "$debugnum"
# Each remaining argument is printed with one leading space; with no arguments
# a single space is still printed after the colon.
printf 'Other args:'
printf ' %s' "$@"
printf '\n'
printf 'UID: %s\n' "$uid"
printf '%s %s --> %s\n' "$dataset_name" "$source" "$target"
# Change to the project directory and launch training.
# A failed cd aborts the job instead of letting python3 run main.py from
# whatever directory the job started in.
cd "$remotedir" || { echo "could not change directory to $remotedir"; handle_error; }

# Log and model folders get the suffix appended. additional_args and any
# remaining command-line arguments ("$@") are forwarded unchanged. A non-zero
# exit from main.py ends the job through handle_error.
python3 main.py \
    --logdir="$logFolder-$suffix" --modeldir="$modelFolder-$suffix" \
    --method="$method" --dataset="$dataset_name" --sources="$source" \
    --target="$target" --uid="$uid" --debugnum="$debugnum" \
    --gpumem=0 "${additional_args[@]}" "$@" || handle_error