mirror of
https://github.com/ION606/COGMOD-HWI.git
synced 2026-05-14 22:16:57 +00:00
added base project
This commit is contained in:
@@ -0,0 +1 @@
|
||||
data/
|
||||
Binary file not shown.
@@ -0,0 +1,12 @@
|
||||
all: run analyze doc
|
||||
|
||||
run:
|
||||
python main.py
|
||||
|
||||
analyze:
|
||||
python main.py --analyze
|
||||
|
||||
doc:
|
||||
pdflatex report.tex
|
||||
pdflatex report.tex
|
||||
pdflatex report.tex
|
||||
@@ -0,0 +1,19 @@
|
||||
seed,optimizer,augmentation,test_acc,robustness
|
||||
42,sgd,none,0.7088,"{'0.1': 0.6319, '0.2': 0.4336, '0.3': 0.2913}"
|
||||
42,sgd,standard,0.6859,"{'0.1': 0.5952, '0.2': 0.4019, '0.3': 0.2757}"
|
||||
42,sgd,aggressive,0.6536,"{'0.1': 0.5778, '0.2': 0.43, '0.3': 0.2943}"
|
||||
42,adam,none,0.5451,"{'0.1': 0.4221, '0.2': 0.2298, '0.3': 0.1545}"
|
||||
42,adam,standard,0.5101,"{'0.1': 0.454, '0.2': 0.2098, '0.3': 0.1324}"
|
||||
42,adam,aggressive,0.4427,"{'0.1': 0.4048, '0.2': 0.2461, '0.3': 0.1547}"
|
||||
123,sgd,none,0.6974,"{'0.1': 0.63, '0.2': 0.4452, '0.3': 0.312}"
|
||||
123,sgd,standard,0.6674,"{'0.1': 0.6252, '0.2': 0.4146, '0.3': 0.2764}"
|
||||
123,sgd,aggressive,0.6691,"{'0.1': 0.6179, '0.2': 0.4691, '0.3': 0.3423}"
|
||||
123,adam,none,0.6049,"{'0.1': 0.4685, '0.2': 0.3387, '0.3': 0.2378}"
|
||||
123,adam,standard,0.4654,"{'0.1': 0.4071, '0.2': 0.3073, '0.3': 0.2341}"
|
||||
123,adam,aggressive,0.5096,"{'0.1': 0.4624, '0.2': 0.3219, '0.3': 0.2159}"
|
||||
999,sgd,none,0.7058,"{'0.1': 0.6252, '0.2': 0.3848, '0.3': 0.2276}"
|
||||
999,sgd,standard,0.6861,"{'0.1': 0.6002, '0.2': 0.4184, '0.3': 0.2986}"
|
||||
999,sgd,aggressive,0.6595,"{'0.1': 0.5775, '0.2': 0.4165, '0.3': 0.2899}"
|
||||
999,adam,none,0.5573,"{'0.1': 0.4562, '0.2': 0.293, '0.3': 0.2167}"
|
||||
999,adam,standard,0.4835,"{'0.1': 0.4136, '0.2': 0.2221, '0.3': 0.1548}"
|
||||
999,adam,aggressive,0.5123,"{'0.1': 0.449, '0.2': 0.2571, '0.3': 0.1658}"
|
||||
|
@@ -0,0 +1,19 @@
|
||||
seed,optimizer,augmentation,test_acc,robustness
|
||||
42,sgd,none,0.7088,"{'0.1': 0.6319, '0.2': 0.4336, '0.3': 0.2913}"
|
||||
42,sgd,standard,0.6859,"{'0.1': 0.5952, '0.2': 0.4019, '0.3': 0.2757}"
|
||||
42,sgd,aggressive,0.6536,"{'0.1': 0.5778, '0.2': 0.43, '0.3': 0.2943}"
|
||||
42,adam,none,0.5451,"{'0.1': 0.4221, '0.2': 0.2298, '0.3': 0.1545}"
|
||||
42,adam,standard,0.5101,"{'0.1': 0.454, '0.2': 0.2098, '0.3': 0.1324}"
|
||||
42,adam,aggressive,0.4427,"{'0.1': 0.4048, '0.2': 0.2461, '0.3': 0.1547}"
|
||||
123,sgd,none,0.6974,"{'0.1': 0.63, '0.2': 0.4452, '0.3': 0.312}"
|
||||
123,sgd,standard,0.6674,"{'0.1': 0.6252, '0.2': 0.4146, '0.3': 0.2764}"
|
||||
123,sgd,aggressive,0.6691,"{'0.1': 0.6179, '0.2': 0.4691, '0.3': 0.3423}"
|
||||
123,adam,none,0.6049,"{'0.1': 0.4685, '0.2': 0.3387, '0.3': 0.2378}"
|
||||
123,adam,standard,0.4654,"{'0.1': 0.4071, '0.2': 0.3073, '0.3': 0.2341}"
|
||||
123,adam,aggressive,0.5096,"{'0.1': 0.4624, '0.2': 0.3219, '0.3': 0.2159}"
|
||||
999,sgd,none,0.7058,"{'0.1': 0.6252, '0.2': 0.3848, '0.3': 0.2276}"
|
||||
999,sgd,standard,0.6861,"{'0.1': 0.6002, '0.2': 0.4184, '0.3': 0.2986}"
|
||||
999,sgd,aggressive,0.6595,"{'0.1': 0.5775, '0.2': 0.4165, '0.3': 0.2899}"
|
||||
999,adam,none,0.5573,"{'0.1': 0.4562, '0.2': 0.293, '0.3': 0.2167}"
|
||||
999,adam,standard,0.4835,"{'0.1': 0.4136, '0.2': 0.2221, '0.3': 0.1548}"
|
||||
999,adam,aggressive,0.5123,"{'0.1': 0.449, '0.2': 0.2571, '0.3': 0.1658}"
|
||||
|
@@ -0,0 +1,11 @@
|
||||
epoch,train_loss,train_acc,test_loss,test_acc
|
||||
1,1.8846732257843017,0.3103,1.5956477916717529,0.4208
|
||||
2,1.6917855532073975,0.38452,1.5259598985671996,0.439
|
||||
3,1.6328186359024048,0.40266,1.5264780994415283,0.4508
|
||||
4,1.5851124533462524,0.42526,1.424309602355957,0.4878
|
||||
5,1.5511348150634765,0.43632,1.4553828117370606,0.4854
|
||||
6,1.5411012967681885,0.44224,1.5167289329528808,0.4561
|
||||
7,1.5281506721115112,0.4473,1.3984178335189819,0.5015
|
||||
8,1.5113245751190185,0.4561,1.3988324977874755,0.5007
|
||||
9,1.5111655992889403,0.45554,1.4732477207183838,0.4733
|
||||
10,1.498042034225464,0.46026,1.3725576107025146,0.5096
|
||||
|
@@ -0,0 +1,11 @@
|
||||
epoch,train_loss,train_acc,test_loss,test_acc
|
||||
1,2.0559358141326904,0.24472,1.828050934791565,0.3374
|
||||
2,1.857869116783142,0.3182,1.6779026025772095,0.3909
|
||||
3,1.7681750354385375,0.35418,1.6347509387969972,0.4094
|
||||
4,1.7275719815826416,0.37254,1.6000082025527953,0.4225
|
||||
5,1.699682864151001,0.37816,1.571951426887512,0.4362
|
||||
6,1.6953852003860475,0.38134,1.582085866355896,0.4284
|
||||
7,1.6721481609344482,0.392,1.584311390686035,0.4267
|
||||
8,1.6576726916885376,0.39738,1.5268918621063232,0.443
|
||||
9,1.6475247369384765,0.402,1.5393544744491576,0.4441
|
||||
10,1.6505113174819945,0.39736,1.5213500274658203,0.4427
|
||||
|
@@ -0,0 +1,11 @@
|
||||
epoch,train_loss,train_acc,test_loss,test_acc
|
||||
1,1.9911173669052125,0.27896,1.6828551885604859,0.3994
|
||||
2,1.6908308898162843,0.38222,1.6040934282302857,0.4364
|
||||
3,1.6209180907440186,0.41298,1.4993294555664063,0.459
|
||||
4,1.591390087928772,0.42586,1.5083149290084839,0.4616
|
||||
5,1.5589105658340454,0.43662,1.428050018119812,0.4809
|
||||
6,1.540269641456604,0.44422,1.4299810447692871,0.4858
|
||||
7,1.5332854122543336,0.44906,1.3867164831161498,0.5119
|
||||
8,1.5187089794921875,0.4558,1.3682811828613282,0.5136
|
||||
9,1.5207957489013673,0.45552,1.4490770374298096,0.4736
|
||||
10,1.504308384437561,0.4596,1.3697486953735352,0.5123
|
||||
|
@@ -0,0 +1,11 @@
|
||||
epoch,train_loss,train_acc,test_loss,test_acc
|
||||
1,1.6660349139404298,0.39084,1.3631993772506714,0.5092
|
||||
2,1.306575980796814,0.53194,1.2523228994369506,0.5514
|
||||
3,1.189557453918457,0.57828,1.2570601411819458,0.5594
|
||||
4,1.122563935699463,0.60502,1.187407196044922,0.5902
|
||||
5,1.0689283304977417,0.6233,1.2376863445281983,0.5659
|
||||
6,1.0334041289138793,0.6345,1.170531530189514,0.592
|
||||
7,0.9959703926849365,0.64648,1.1732149269104004,0.5901
|
||||
8,0.9762340403366089,0.65428,1.185727474975586,0.5919
|
||||
9,0.962807445487976,0.65934,1.1849331123352052,0.5937
|
||||
10,0.9365197282409667,0.66622,1.1537712555885316,0.6049
|
||||
|
@@ -0,0 +1,11 @@
|
||||
epoch,train_loss,train_acc,test_loss,test_acc
|
||||
1,1.7286399437713622,0.37382,1.4246170207977296,0.4816
|
||||
2,1.3806899689483643,0.50334,1.3549695669174195,0.5164
|
||||
3,1.3021614087677003,0.53534,1.332729033279419,0.5202
|
||||
4,1.2415979095458984,0.55796,1.2883787483215332,0.542
|
||||
5,1.1996409232330323,0.5716,1.2885587697982788,0.5404
|
||||
6,1.1598682831573486,0.58466,1.2654446369171142,0.5521
|
||||
7,1.1435769744873048,0.59348,1.2620088846206665,0.5593
|
||||
8,1.122702364425659,0.60254,1.2513995946884156,0.5613
|
||||
9,1.0767982495117188,0.6175,1.2432810546875,0.5644
|
||||
10,1.0651884030914307,0.6178,1.321402802658081,0.5451
|
||||
|
@@ -0,0 +1,11 @@
|
||||
epoch,train_loss,train_acc,test_loss,test_acc
|
||||
1,1.7729383182525635,0.35124,1.5228908493041993,0.4427
|
||||
2,1.430595055770874,0.482,1.38000087184906,0.5018
|
||||
3,1.32311696308136,0.52478,1.313341820716858,0.5259
|
||||
4,1.2678979627990723,0.54632,1.2857267393112182,0.5393
|
||||
5,1.2276857889556885,0.56124,1.299703761291504,0.5393
|
||||
6,1.1998348657989502,0.57262,1.2608893209457397,0.5565
|
||||
7,1.1683492805480957,0.58382,1.27673601436615,0.5546
|
||||
8,1.157327294807434,0.5878,1.2376579500198364,0.5654
|
||||
9,1.131808783454895,0.59764,1.2557779052734375,0.5572
|
||||
10,1.123384162902832,0.60062,1.2491890354156494,0.5573
|
||||
|
@@ -0,0 +1,11 @@
|
||||
epoch,train_loss,train_acc,test_loss,test_acc
|
||||
1,1.8813365224456786,0.30172,1.6416668697357177,0.3985
|
||||
2,1.6806913624954223,0.37672,1.5487596210479737,0.4274
|
||||
3,1.6351341115570068,0.40124,1.5625509996414184,0.4249
|
||||
4,1.6064803477859497,0.41056,1.511009596824646,0.4455
|
||||
5,1.5825084410476684,0.42064,1.5225824378967285,0.4431
|
||||
6,1.5676781223678589,0.42982,1.4784889535903931,0.4585
|
||||
7,1.5567302836608887,0.43112,1.4934770944595337,0.4557
|
||||
8,1.5449237117004395,0.43588,1.463665308380127,0.4715
|
||||
9,1.5365594310760498,0.44164,1.4671966415405273,0.4741
|
||||
10,1.5271776266098023,0.44406,1.4860247068405152,0.4654
|
||||
|
@@ -0,0 +1,11 @@
|
||||
epoch,train_loss,train_acc,test_loss,test_acc
|
||||
1,1.9121077362060548,0.28534,1.6980935123443603,0.3717
|
||||
2,1.634671248512268,0.3994,1.5801765157699585,0.4126
|
||||
3,1.5572229358291625,0.43102,1.4792659259796141,0.4583
|
||||
4,1.5200784818267823,0.44598,1.4566459318161011,0.4732
|
||||
5,1.4923802797317505,0.4583,1.3978405960083007,0.4928
|
||||
6,1.4686766564178466,0.46806,1.3776366807937621,0.503
|
||||
7,1.4700688259124757,0.46884,1.3533947219848632,0.5163
|
||||
8,1.4470393926620484,0.4773,1.345004845237732,0.5178
|
||||
9,1.4390167654800414,0.47734,1.3564238325119018,0.5147
|
||||
10,1.4347680015182496,0.48004,1.3694934066772462,0.5101
|
||||
|
@@ -0,0 +1,11 @@
|
||||
epoch,train_loss,train_acc,test_loss,test_acc
|
||||
1,1.8785975454711914,0.30716,1.543831322669983,0.4375
|
||||
2,1.6268331103515625,0.40026,1.5018853443145752,0.4532
|
||||
3,1.5749516596221924,0.42014,1.5034992160797118,0.452
|
||||
4,1.5602307732009888,0.42962,1.4366535945892334,0.477
|
||||
5,1.5448442478561402,0.43214,1.439453507232666,0.4691
|
||||
6,1.5286526789474488,0.43836,1.4527703874588012,0.4671
|
||||
7,1.5160772939300537,0.4433,1.4506785593032836,0.4667
|
||||
8,1.5187876448822022,0.44144,1.4333687414169312,0.4782
|
||||
9,1.5096457154083252,0.44728,1.4133702894210816,0.4894
|
||||
10,1.495245510482788,0.45166,1.411236897277832,0.4835
|
||||
|
@@ -0,0 +1,11 @@
|
||||
epoch,train_loss,train_acc,test_loss,test_acc
|
||||
1,2.0651948305892947,0.24042,1.729777430343628,0.3836
|
||||
2,1.7303296995544433,0.37548,1.49167406539917,0.4667
|
||||
3,1.5553171773147583,0.4367,1.3456432037353516,0.5205
|
||||
4,1.4635125664901734,0.47256,1.2475262697219849,0.5611
|
||||
5,1.3722680527114868,0.50766,1.1914715614318847,0.5746
|
||||
6,1.3082861290740966,0.53284,1.121172977924347,0.5981
|
||||
7,1.2505797283554076,0.5567,1.1289072477340698,0.5922
|
||||
8,1.2003705361175536,0.57516,1.0124303802490235,0.6468
|
||||
9,1.1683612518310547,0.58702,0.9980959354400635,0.65
|
||||
10,1.1325882072067261,0.6023,0.9680170437812805,0.6691
|
||||
|
@@ -0,0 +1,11 @@
|
||||
epoch,train_loss,train_acc,test_loss,test_acc
|
||||
1,2.0705010092163088,0.22946,1.7674200216293334,0.3802
|
||||
2,1.7432165607452392,0.36844,1.533634359550476,0.4416
|
||||
3,1.5778629602050782,0.42886,1.3606201539993286,0.5097
|
||||
4,1.4700220097351073,0.46802,1.2655291788101197,0.5451
|
||||
5,1.3893862116241456,0.50128,1.204566476535797,0.5692
|
||||
6,1.3244363966751098,0.5255,1.2020204118728637,0.5744
|
||||
7,1.2688888247299195,0.54674,1.0509296971321105,0.6334
|
||||
8,1.2191224797821045,0.56592,1.0520767385482788,0.6286
|
||||
9,1.1705705532073976,0.58452,1.0169856172561647,0.6394
|
||||
10,1.1425927544021606,0.59578,0.9944705787658692,0.6536
|
||||
|
@@ -0,0 +1,11 @@
|
||||
epoch,train_loss,train_acc,test_loss,test_acc
|
||||
1,2.028081403198242,0.25312,1.7972938034057617,0.3576
|
||||
2,1.6919663692855835,0.38654,1.4519519706726074,0.4827
|
||||
3,1.5530475779342652,0.43992,1.3310353057861328,0.5251
|
||||
4,1.4540857498931885,0.47812,1.292456304550171,0.5413
|
||||
5,1.3654178101348877,0.51218,1.1448781085014343,0.6
|
||||
6,1.3064117618942261,0.53414,1.1193419974327088,0.6077
|
||||
7,1.2500248126220703,0.5551,1.0644649827957153,0.6217
|
||||
8,1.2079097275543214,0.5713,1.0305480089187622,0.6358
|
||||
9,1.1659785708236694,0.58452,0.9741595977783203,0.6535
|
||||
10,1.1451263586425782,0.59474,0.9785669965744018,0.6595
|
||||
|
@@ -0,0 +1,11 @@
|
||||
epoch,train_loss,train_acc,test_loss,test_acc
|
||||
1,1.9007139762496947,0.30904,1.5374012899398803,0.4448
|
||||
2,1.4438594173431396,0.48224,1.360331530570984,0.5054
|
||||
3,1.2486270139312745,0.5548,1.1819912509918213,0.5701
|
||||
4,1.1134616537475586,0.60458,1.0598442848205567,0.6256
|
||||
5,0.9943869771957398,0.6509,0.9994649070739746,0.6511
|
||||
6,0.9061599386596679,0.68082,0.9807863761901855,0.6526
|
||||
7,0.832235673122406,0.70842,0.9603629438400269,0.6641
|
||||
8,0.7533782648849487,0.7388,0.9113856033325195,0.6857
|
||||
9,0.6814181573486328,0.76166,0.9031674418449401,0.6972
|
||||
10,0.6241655448532104,0.78184,0.8935547742843628,0.6974
|
||||
|
@@ -0,0 +1,11 @@
|
||||
epoch,train_loss,train_acc,test_loss,test_acc
|
||||
1,1.9639052035522462,0.28114,1.6098346630096436,0.4097
|
||||
2,1.4680433963775634,0.46986,1.302504965209961,0.5284
|
||||
3,1.2549403902435303,0.5519,1.225121039199829,0.5624
|
||||
4,1.1094685572624206,0.60738,1.0919673002243042,0.612
|
||||
5,0.9939155144119263,0.6497,0.9837436180114746,0.659
|
||||
6,0.8942083934020996,0.6852,0.957837422657013,0.6612
|
||||
7,0.8110888798904419,0.71492,0.9374553295135498,0.6761
|
||||
8,0.744535570487976,0.7387,0.9072779047966003,0.6925
|
||||
9,0.6828387829208374,0.76086,0.8724043285369874,0.7033
|
||||
10,0.6119843885612488,0.78486,0.8666944108009338,0.7088
|
||||
|
@@ -0,0 +1,11 @@
|
||||
epoch,train_loss,train_acc,test_loss,test_acc
|
||||
1,1.9422726140975952,0.28742,1.637837126159668,0.4084
|
||||
2,1.4907189490509034,0.46088,1.3707801050186157,0.5124
|
||||
3,1.2916603052902222,0.5373,1.1907301538467407,0.5763
|
||||
4,1.1448502359771728,0.59362,1.0813807232856751,0.6132
|
||||
5,1.0147696270942688,0.64488,0.9749312539100647,0.658
|
||||
6,0.9084624612808228,0.68092,0.971125221824646,0.6546
|
||||
7,0.8264397340011597,0.70914,0.9333860067367554,0.6724
|
||||
8,0.7552365953445435,0.73546,0.9200770247459411,0.6771
|
||||
9,0.6857615605545044,0.76032,0.8947711204528809,0.6944
|
||||
10,0.6202431632995605,0.78266,0.8700129133224487,0.7058
|
||||
|
@@ -0,0 +1,11 @@
|
||||
epoch,train_loss,train_acc,test_loss,test_acc
|
||||
1,2.029623507652283,0.2529,1.688811130142212,0.3927
|
||||
2,1.6488482897186278,0.40018,1.4374095891952514,0.4752
|
||||
3,1.4772799210739136,0.46326,1.3251022443771363,0.5178
|
||||
4,1.3625635675048828,0.50938,1.2472985363006592,0.5517
|
||||
5,1.2698390966033934,0.54114,1.1421482133865357,0.5942
|
||||
6,1.1823028423309325,0.57602,1.0808060108184814,0.6158
|
||||
7,1.118297325630188,0.60142,0.9704109483718872,0.6591
|
||||
8,1.0673625009155274,0.62172,0.9435201133728027,0.6721
|
||||
9,1.0345378282928466,0.6332,0.9358344539642334,0.6687
|
||||
10,0.9860974870300293,0.65256,0.950194506072998,0.6674
|
||||
|
@@ -0,0 +1,11 @@
|
||||
epoch,train_loss,train_acc,test_loss,test_acc
|
||||
1,2.009225925979614,0.26266,1.6814233026504517,0.3895
|
||||
2,1.6603656655120849,0.39688,1.4227338985443114,0.4802
|
||||
3,1.4855073391342164,0.45954,1.3364979927062988,0.5176
|
||||
4,1.3897099987411499,0.49994,1.2245563619613646,0.562
|
||||
5,1.2903345095443726,0.5355,1.2043679784774781,0.5667
|
||||
6,1.2084671304321288,0.57048,1.0968135701179504,0.6135
|
||||
7,1.1449422802734375,0.5905,1.0201958515167235,0.6367
|
||||
8,1.0922849000167847,0.61064,0.9836806530952453,0.6468
|
||||
9,1.0471580500793456,0.62822,0.9301959774017334,0.6716
|
||||
10,1.0136299449157715,0.64068,0.9022101957321167,0.6859
|
||||
|
@@ -0,0 +1,11 @@
|
||||
epoch,train_loss,train_acc,test_loss,test_acc
|
||||
1,1.959310048789978,0.28234,1.5769516773223877,0.4368
|
||||
2,1.5942330165863037,0.42066,1.4556853477478027,0.4914
|
||||
3,1.443750505142212,0.47874,1.316703634262085,0.5271
|
||||
4,1.3386957333755494,0.51598,1.1774909160614013,0.5831
|
||||
5,1.2375807570266724,0.55794,1.0962469652175904,0.6125
|
||||
6,1.164957059440613,0.58492,1.0318147755622864,0.6338
|
||||
7,1.102748593158722,0.60784,0.9843499254226684,0.6612
|
||||
8,1.0541843451309205,0.62636,0.9331276074409485,0.6778
|
||||
9,1.0117439514541626,0.64112,0.8982747142791748,0.6885
|
||||
10,0.982185898399353,0.65222,0.9087156764984131,0.6861
|
||||
|
@@ -0,0 +1,7 @@
|
||||
from torchvision.datasets import CIFAR10, CIFAR100
|
||||
import tensorflow_datasets as tfds
|
||||
|
||||
ds10 = CIFAR10(root='data/', train=True, download=True)
|
||||
ds100 = CIFAR100(root='data/', train=True, download=True)
|
||||
|
||||
ds_c10c = tfds.load('cifar10_corrupted')
|
||||
+257
@@ -0,0 +1,257 @@
|
||||
import os
|
||||
import argparse
|
||||
from itertools import product
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torch.utils.data import DataLoader
|
||||
import torchvision
|
||||
from torchvision import transforms
|
||||
|
||||
import json
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
import statsmodels.api as sm
|
||||
from statsmodels.formula.api import ols
|
||||
|
||||
# simple cnn model definition
|
||||
# I looked a lot at https://github.com/giusarno/SimpleCNN/blob/master/examples/cifar10/themodel.py
|
||||
# before making this class, mostly because I was not aware of the `MaxPool2d` function
|
||||
|
||||
|
||||
class SimpleCNN(nn.Module):
|
||||
def __init__(self, num_classes=10):
|
||||
super(SimpleCNN, self).__init__()
|
||||
self.features = nn.Sequential(
|
||||
nn.Conv2d(3, 32, 3, padding=1), nn.ReLU(),
|
||||
nn.MaxPool2d(2),
|
||||
nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(),
|
||||
nn.MaxPool2d(2),
|
||||
)
|
||||
self.classifier = nn.Sequential(
|
||||
nn.Flatten(),
|
||||
nn.Linear(64 * 8 * 8, 128), nn.ReLU(),
|
||||
nn.Linear(128, num_classes),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.features(x)
|
||||
x = self.classifier(x)
|
||||
return x
|
||||
|
||||
|
||||
def get_data_loaders(batch_size, augmentation):
|
||||
# transform pipelines
|
||||
if augmentation == 'none':
|
||||
transform_train = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
])
|
||||
elif augmentation == 'standard':
|
||||
transform_train = transforms.Compose([
|
||||
transforms.RandomHorizontalFlip(),
|
||||
transforms.RandomCrop(32, padding=4),
|
||||
transforms.ToTensor(),
|
||||
])
|
||||
elif augmentation == 'aggressive':
|
||||
transform_train = transforms.Compose([
|
||||
transforms.RandomHorizontalFlip(),
|
||||
transforms.RandomRotation(15),
|
||||
transforms.RandomCrop(32, padding=4),
|
||||
transforms.ColorJitter(brightness=0.2, contrast=0.2,
|
||||
saturation=0.2, hue=0.1),
|
||||
transforms.ToTensor(),
|
||||
])
|
||||
else:
|
||||
raise ValueError(f"unknown augmentation: {augmentation}")
|
||||
|
||||
transform_test = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
])
|
||||
|
||||
train_dataset = torchvision.datasets.CIFAR10(
|
||||
root='./data', train=True, download=True, transform=transform_train)
|
||||
test_dataset = torchvision.datasets.CIFAR10(
|
||||
root='./data', train=False, download=True, transform=transform_test)
|
||||
|
||||
train_loader = DataLoader(
|
||||
train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
|
||||
test_loader = DataLoader(
|
||||
test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
|
||||
|
||||
return train_loader, test_loader
|
||||
|
||||
|
||||
# train for 1 epoch
|
||||
|
||||
|
||||
def train_one_epoch(model, optimizer, criterion, dataloader, device):
|
||||
model.train()
|
||||
running_loss = 0.0
|
||||
correct = 0
|
||||
total = 0
|
||||
for inputs, targets in dataloader:
|
||||
inputs, targets = inputs.to(device), targets.to(device)
|
||||
optimizer.zero_grad()
|
||||
outputs = model(inputs)
|
||||
loss = criterion(outputs, targets)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
running_loss += loss.item() * inputs.size(0)
|
||||
_, predicted = outputs.max(1)
|
||||
correct += predicted.eq(targets).sum().item()
|
||||
total += targets.size(0)
|
||||
epoch_loss = running_loss / total
|
||||
epoch_acc = correct / total
|
||||
return epoch_loss, epoch_acc
|
||||
|
||||
|
||||
# eval on clean data
|
||||
|
||||
|
||||
def evaluate(model, criterion, dataloader, device):
|
||||
model.eval()
|
||||
running_loss = 0.0
|
||||
correct = 0
|
||||
total = 0
|
||||
with torch.no_grad():
|
||||
for inputs, targets in dataloader:
|
||||
inputs, targets = inputs.to(device), targets.to(device)
|
||||
outputs = model(inputs)
|
||||
loss = criterion(outputs, targets)
|
||||
|
||||
running_loss += loss.item() * inputs.size(0)
|
||||
_, predicted = outputs.max(1)
|
||||
correct += predicted.eq(targets).sum().item()
|
||||
total += targets.size(0)
|
||||
loss = running_loss / total
|
||||
acc = correct / total
|
||||
return loss, acc
|
||||
|
||||
|
||||
# eval robustness under gaussian noise
|
||||
|
||||
|
||||
def evaluate_robustness(model, dataloader, device, noise_std):
|
||||
model.eval()
|
||||
correct = 0
|
||||
total = 0
|
||||
with torch.no_grad():
|
||||
for inputs, targets in dataloader:
|
||||
noisy_inputs = inputs + noise_std * torch.randn_like(inputs)
|
||||
noisy_inputs = torch.clamp(noisy_inputs, 0.0, 1.0)
|
||||
noisy_inputs, targets = noisy_inputs.to(device), targets.to(device)
|
||||
outputs = model(noisy_inputs)
|
||||
_, predicted = outputs.max(1)
|
||||
correct += predicted.eq(targets).sum().item()
|
||||
total += targets.size(0)
|
||||
acc = correct / total
|
||||
return acc
|
||||
|
||||
|
||||
def analyze_results(results_path='results.json'):
|
||||
import json
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
from statsmodels.formula.api import ols
|
||||
import statsmodels.api as sm
|
||||
|
||||
with open(results_path) as f:
|
||||
results = json.load(f)
|
||||
df = pd.DataFrame(results)
|
||||
df.to_csv('analysis_results.csv', index=False)
|
||||
|
||||
# full ANOVA w/interaction
|
||||
model = ols('test_acc ~ C(optimizer) * C(augmentation)', data=df).fit()
|
||||
anova_table = sm.stats.anova_lm(model, typ=2)
|
||||
print('anova on test accuracy:')
|
||||
print(anova_table)
|
||||
|
||||
# composite label
|
||||
df['condition'] = df['optimizer'] + '_' + df['augmentation']
|
||||
df.plot.bar(x='condition', y='test_acc', rot=45)
|
||||
plt.ylabel('test accuracy')
|
||||
plt.tight_layout()
|
||||
plt.savefig('test_acc_comparison.png')
|
||||
print('saved plot to test_acc_comparison.png')
|
||||
|
||||
|
||||
# main (PUBLIC STATIC VOID AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA)
|
||||
def run_experiments(args):
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
results = []
|
||||
|
||||
optimizers = {
|
||||
'sgd': lambda params: optim.SGD(params, lr=args.lr, momentum=0.9),
|
||||
'adam': lambda params: optim.Adam(params, lr=args.lr)
|
||||
}
|
||||
augmentations = ['none', 'standard', 'aggressive']
|
||||
|
||||
seeds = [42, 123, 999]
|
||||
for seed in seeds:
|
||||
torch.manual_seed(seed)
|
||||
np.random.seed(seed)
|
||||
for opt_name in optimizers:
|
||||
for aug in augmentations:
|
||||
train_loader, test_loader = get_data_loaders(args.batch_size, aug)
|
||||
model = SimpleCNN(num_classes=10).to(device)
|
||||
optimizer = optimizers[opt_name](model.parameters())
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
history = {
|
||||
'epoch': [], 'train_loss': [], 'train_acc': [],
|
||||
'test_loss': [], 'test_acc': []
|
||||
}
|
||||
|
||||
for epoch in range(args.epochs):
|
||||
train_loss, train_acc = train_one_epoch(
|
||||
model, optimizer, criterion, train_loader, device)
|
||||
test_loss, test_acc = evaluate(
|
||||
model, criterion, test_loader, device)
|
||||
|
||||
history['epoch'].append(epoch + 1)
|
||||
history['train_loss'].append(train_loss)
|
||||
history['train_acc'].append(train_acc)
|
||||
history['test_loss'].append(test_loss)
|
||||
history['test_acc'].append(test_acc)
|
||||
|
||||
print(f"[{opt_name}][{aug}][epoch {epoch + 1}] "
|
||||
f"train_loss={train_loss:.4f}, train_acc={train_acc:.4f}, "
|
||||
f"test_acc={test_acc:.4f}")
|
||||
|
||||
noise_levels = [0.1, 0.2, 0.3]
|
||||
robustness = {noise: evaluate_robustness(
|
||||
model, test_loader, device, noise) for noise in noise_levels}
|
||||
|
||||
pd.DataFrame(history).to_csv(
|
||||
f"history_{opt_name}_{aug}_{seed}.csv", index=False)
|
||||
results.append({
|
||||
'seed': seed,
|
||||
'optimizer': opt_name,
|
||||
'augmentation': aug,
|
||||
'test_acc': test_acc,
|
||||
'robustness': robustness
|
||||
})
|
||||
|
||||
with open('results.json', 'w') as f:
|
||||
json.dump(results, f, indent=2)
|
||||
print('saved results to results.json')
|
||||
|
||||
|
||||
# credit: I gave chatgpt a list of args and it made the arg parser for me
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument('--batch_size', type=int, default=128)
|
||||
parser.add_argument('--lr', type=float, default=0.01)
|
||||
parser.add_argument('--epochs', type=int, default=10)
|
||||
parser.add_argument('--analyze', action='store_true',
|
||||
help='run analysis on results')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.analyze:
|
||||
analyze_results()
|
||||
else:
|
||||
run_experiments(args)
|
||||
@@ -0,0 +1,19 @@
|
||||
seed,optimizer,augmentation,test_acc,robustness
|
||||
42,sgd,none,0.7088,"{'0.1': 0.6319, '0.2': 0.4336, '0.3': 0.2913}"
|
||||
42,sgd,standard,0.6859,"{'0.1': 0.5952, '0.2': 0.4019, '0.3': 0.2757}"
|
||||
42,sgd,aggressive,0.6536,"{'0.1': 0.5778, '0.2': 0.43, '0.3': 0.2943}"
|
||||
42,adam,none,0.5451,"{'0.1': 0.4221, '0.2': 0.2298, '0.3': 0.1545}"
|
||||
42,adam,standard,0.5101,"{'0.1': 0.454, '0.2': 0.2098, '0.3': 0.1324}"
|
||||
42,adam,aggressive,0.4427,"{'0.1': 0.4048, '0.2': 0.2461, '0.3': 0.1547}"
|
||||
123,sgd,none,0.6974,"{'0.1': 0.63, '0.2': 0.4452, '0.3': 0.312}"
|
||||
123,sgd,standard,0.6674,"{'0.1': 0.6252, '0.2': 0.4146, '0.3': 0.2764}"
|
||||
123,sgd,aggressive,0.6691,"{'0.1': 0.6179, '0.2': 0.4691, '0.3': 0.3423}"
|
||||
123,adam,none,0.6049,"{'0.1': 0.4685, '0.2': 0.3387, '0.3': 0.2378}"
|
||||
123,adam,standard,0.4654,"{'0.1': 0.4071, '0.2': 0.3073, '0.3': 0.2341}"
|
||||
123,adam,aggressive,0.5096,"{'0.1': 0.4624, '0.2': 0.3219, '0.3': 0.2159}"
|
||||
999,sgd,none,0.7058,"{'0.1': 0.6252, '0.2': 0.3848, '0.3': 0.2276}"
|
||||
999,sgd,standard,0.6861,"{'0.1': 0.6002, '0.2': 0.4184, '0.3': 0.2986}"
|
||||
999,sgd,aggressive,0.6595,"{'0.1': 0.5775, '0.2': 0.4165, '0.3': 0.2899}"
|
||||
999,adam,none,0.5573,"{'0.1': 0.4562, '0.2': 0.293, '0.3': 0.2167}"
|
||||
999,adam,standard,0.4835,"{'0.1': 0.4136, '0.2': 0.2221, '0.3': 0.1548}"
|
||||
999,adam,aggressive,0.5123,"{'0.1': 0.449, '0.2': 0.2571, '0.3': 0.1658}"
|
||||
|
@@ -0,0 +1,60 @@
|
||||
[sgd][none][epoch 0] train_loss=1.9493, train_acc=0.2892, test_acc=0.4237
|
||||
[sgd][none][epoch 1] train_loss=1.4806, train_acc=0.4651, test_acc=0.5235
|
||||
[sgd][none][epoch 2] train_loss=1.2920, train_acc=0.5367, test_acc=0.5619
|
||||
[sgd][none][epoch 3] train_loss=1.1722, train_acc=0.5809, test_acc=0.6119
|
||||
[sgd][none][epoch 4] train_loss=1.0507, train_acc=0.6286, test_acc=0.6334
|
||||
[sgd][none][epoch 5] train_loss=0.9572, train_acc=0.6634, test_acc=0.6452
|
||||
[sgd][none][epoch 6] train_loss=0.8812, train_acc=0.6916, test_acc=0.6703
|
||||
[sgd][none][epoch 7] train_loss=0.7986, train_acc=0.7200, test_acc=0.6722
|
||||
[sgd][none][epoch 8] train_loss=0.7448, train_acc=0.7412, test_acc=0.6824
|
||||
[sgd][none][epoch 9] train_loss=0.6798, train_acc=0.7641, test_acc=0.6843
|
||||
[sgd][standard][epoch 0] train_loss=2.0006, train_acc=0.2638, test_acc=0.4103
|
||||
[sgd][standard][epoch 1] train_loss=1.6251, train_acc=0.4074, test_acc=0.4890
|
||||
[sgd][standard][epoch 2] train_loss=1.4750, train_acc=0.4641, test_acc=0.5426
|
||||
[sgd][standard][epoch 3] train_loss=1.3654, train_acc=0.5054, test_acc=0.5678
|
||||
[sgd][standard][epoch 4] train_loss=1.2646, train_acc=0.5472, test_acc=0.6111
|
||||
[sgd][standard][epoch 5] train_loss=1.1843, train_acc=0.5760, test_acc=0.6166
|
||||
[sgd][standard][epoch 6] train_loss=1.1222, train_acc=0.5997, test_acc=0.6571
|
||||
[sgd][standard][epoch 7] train_loss=1.0737, train_acc=0.6188, test_acc=0.6665
|
||||
[sgd][standard][epoch 8] train_loss=1.0308, train_acc=0.6354, test_acc=0.6872
|
||||
[sgd][standard][epoch 9] train_loss=0.9978, train_acc=0.6465, test_acc=0.6807
|
||||
[sgd][aggressive][epoch 0] train_loss=2.0480, train_acc=0.2484, test_acc=0.3840
|
||||
[sgd][aggressive][epoch 1] train_loss=1.7163, train_acc=0.3802, test_acc=0.4501
|
||||
[sgd][aggressive][epoch 2] train_loss=1.5662, train_acc=0.4333, test_acc=0.5033
|
||||
[sgd][aggressive][epoch 3] train_loss=1.4807, train_acc=0.4668, test_acc=0.5330
|
||||
[sgd][aggressive][epoch 4] train_loss=1.4095, train_acc=0.4943, test_acc=0.5762
|
||||
[sgd][aggressive][epoch 5] train_loss=1.3395, train_acc=0.5195, test_acc=0.5879
|
||||
[sgd][aggressive][epoch 6] train_loss=1.2735, train_acc=0.5444, test_acc=0.6154
|
||||
[sgd][aggressive][epoch 7] train_loss=1.2203, train_acc=0.5677, test_acc=0.6368
|
||||
[sgd][aggressive][epoch 8] train_loss=1.1891, train_acc=0.5792, test_acc=0.6300
|
||||
[sgd][aggressive][epoch 9] train_loss=1.1479, train_acc=0.5907, test_acc=0.6630
|
||||
[adam][none][epoch 0] train_loss=1.7509, train_acc=0.3614, test_acc=0.4276
|
||||
[adam][none][epoch 1] train_loss=1.4346, train_acc=0.4818, test_acc=0.4860
|
||||
[adam][none][epoch 2] train_loss=1.3425, train_acc=0.5193, test_acc=0.5122
|
||||
[adam][none][epoch 3] train_loss=1.2968, train_acc=0.5353, test_acc=0.5197
|
||||
[adam][none][epoch 4] train_loss=1.2610, train_acc=0.5499, test_acc=0.5428
|
||||
[adam][none][epoch 5] train_loss=1.2298, train_acc=0.5618, test_acc=0.5206
|
||||
[adam][none][epoch 6] train_loss=1.2102, train_acc=0.5682, test_acc=0.5455
|
||||
[adam][none][epoch 7] train_loss=1.1824, train_acc=0.5800, test_acc=0.5495
|
||||
[adam][none][epoch 8] train_loss=1.1591, train_acc=0.5886, test_acc=0.5656
|
||||
[adam][none][epoch 9] train_loss=1.1332, train_acc=0.5972, test_acc=0.5696
|
||||
[adam][standard][epoch 0] train_loss=1.9005, train_acc=0.3018, test_acc=0.4193
|
||||
[adam][standard][epoch 1] train_loss=1.6180, train_acc=0.4022, test_acc=0.4547
|
||||
[adam][standard][epoch 2] train_loss=1.5576, train_acc=0.4308, test_acc=0.4751
|
||||
[adam][standard][epoch 3] train_loss=1.5089, train_acc=0.4519, test_acc=0.4908
|
||||
[adam][standard][epoch 4] train_loss=1.4817, train_acc=0.4578, test_acc=0.4807
|
||||
[adam][standard][epoch 5] train_loss=1.4661, train_acc=0.4690, test_acc=0.4925
|
||||
[adam][standard][epoch 6] train_loss=1.4498, train_acc=0.4750, test_acc=0.5123
|
||||
[adam][standard][epoch 7] train_loss=1.4318, train_acc=0.4831, test_acc=0.4820
|
||||
[adam][standard][epoch 8] train_loss=1.4296, train_acc=0.4812, test_acc=0.5210
|
||||
[adam][standard][epoch 9] train_loss=1.4231, train_acc=0.4860, test_acc=0.5161
|
||||
[adam][aggressive][epoch 0] train_loss=1.9556, train_acc=0.2839, test_acc=0.3976
|
||||
[adam][aggressive][epoch 1] train_loss=1.7166, train_acc=0.3748, test_acc=0.4414
|
||||
[adam][aggressive][epoch 2] train_loss=1.6507, train_acc=0.4009, test_acc=0.4486
|
||||
[adam][aggressive][epoch 3] train_loss=1.6179, train_acc=0.4119, test_acc=0.4693
|
||||
[adam][aggressive][epoch 4] train_loss=1.5985, train_acc=0.4178, test_acc=0.4676
|
||||
[adam][aggressive][epoch 5] train_loss=1.5799, train_acc=0.4264, test_acc=0.4788
|
||||
[adam][aggressive][epoch 6] train_loss=1.5763, train_acc=0.4274, test_acc=0.4759
|
||||
[adam][aggressive][epoch 7] train_loss=1.5635, train_acc=0.4340, test_acc=0.4687
|
||||
[adam][aggressive][epoch 8] train_loss=1.5546, train_acc=0.4359, test_acc=0.4992
|
||||
[adam][aggressive][epoch 9] train_loss=1.5463, train_acc=0.4410, test_acc=0.4831
|
||||
@@ -0,0 +1,200 @@
|
||||
[
|
||||
{
|
||||
"seed": 42,
|
||||
"optimizer": "sgd",
|
||||
"augmentation": "none",
|
||||
"test_acc": 0.7088,
|
||||
"robustness": {
|
||||
"0.1": 0.6319,
|
||||
"0.2": 0.4336,
|
||||
"0.3": 0.2913
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 42,
|
||||
"optimizer": "sgd",
|
||||
"augmentation": "standard",
|
||||
"test_acc": 0.6859,
|
||||
"robustness": {
|
||||
"0.1": 0.5952,
|
||||
"0.2": 0.4019,
|
||||
"0.3": 0.2757
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 42,
|
||||
"optimizer": "sgd",
|
||||
"augmentation": "aggressive",
|
||||
"test_acc": 0.6536,
|
||||
"robustness": {
|
||||
"0.1": 0.5778,
|
||||
"0.2": 0.43,
|
||||
"0.3": 0.2943
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 42,
|
||||
"optimizer": "adam",
|
||||
"augmentation": "none",
|
||||
"test_acc": 0.5451,
|
||||
"robustness": {
|
||||
"0.1": 0.4221,
|
||||
"0.2": 0.2298,
|
||||
"0.3": 0.1545
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 42,
|
||||
"optimizer": "adam",
|
||||
"augmentation": "standard",
|
||||
"test_acc": 0.5101,
|
||||
"robustness": {
|
||||
"0.1": 0.454,
|
||||
"0.2": 0.2098,
|
||||
"0.3": 0.1324
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 42,
|
||||
"optimizer": "adam",
|
||||
"augmentation": "aggressive",
|
||||
"test_acc": 0.4427,
|
||||
"robustness": {
|
||||
"0.1": 0.4048,
|
||||
"0.2": 0.2461,
|
||||
"0.3": 0.1547
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 123,
|
||||
"optimizer": "sgd",
|
||||
"augmentation": "none",
|
||||
"test_acc": 0.6974,
|
||||
"robustness": {
|
||||
"0.1": 0.63,
|
||||
"0.2": 0.4452,
|
||||
"0.3": 0.312
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 123,
|
||||
"optimizer": "sgd",
|
||||
"augmentation": "standard",
|
||||
"test_acc": 0.6674,
|
||||
"robustness": {
|
||||
"0.1": 0.6252,
|
||||
"0.2": 0.4146,
|
||||
"0.3": 0.2764
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 123,
|
||||
"optimizer": "sgd",
|
||||
"augmentation": "aggressive",
|
||||
"test_acc": 0.6691,
|
||||
"robustness": {
|
||||
"0.1": 0.6179,
|
||||
"0.2": 0.4691,
|
||||
"0.3": 0.3423
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 123,
|
||||
"optimizer": "adam",
|
||||
"augmentation": "none",
|
||||
"test_acc": 0.6049,
|
||||
"robustness": {
|
||||
"0.1": 0.4685,
|
||||
"0.2": 0.3387,
|
||||
"0.3": 0.2378
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 123,
|
||||
"optimizer": "adam",
|
||||
"augmentation": "standard",
|
||||
"test_acc": 0.4654,
|
||||
"robustness": {
|
||||
"0.1": 0.4071,
|
||||
"0.2": 0.3073,
|
||||
"0.3": 0.2341
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 123,
|
||||
"optimizer": "adam",
|
||||
"augmentation": "aggressive",
|
||||
"test_acc": 0.5096,
|
||||
"robustness": {
|
||||
"0.1": 0.4624,
|
||||
"0.2": 0.3219,
|
||||
"0.3": 0.2159
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 999,
|
||||
"optimizer": "sgd",
|
||||
"augmentation": "none",
|
||||
"test_acc": 0.7058,
|
||||
"robustness": {
|
||||
"0.1": 0.6252,
|
||||
"0.2": 0.3848,
|
||||
"0.3": 0.2276
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 999,
|
||||
"optimizer": "sgd",
|
||||
"augmentation": "standard",
|
||||
"test_acc": 0.6861,
|
||||
"robustness": {
|
||||
"0.1": 0.6002,
|
||||
"0.2": 0.4184,
|
||||
"0.3": 0.2986
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 999,
|
||||
"optimizer": "sgd",
|
||||
"augmentation": "aggressive",
|
||||
"test_acc": 0.6595,
|
||||
"robustness": {
|
||||
"0.1": 0.5775,
|
||||
"0.2": 0.4165,
|
||||
"0.3": 0.2899
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 999,
|
||||
"optimizer": "adam",
|
||||
"augmentation": "none",
|
||||
"test_acc": 0.5573,
|
||||
"robustness": {
|
||||
"0.1": 0.4562,
|
||||
"0.2": 0.293,
|
||||
"0.3": 0.2167
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 999,
|
||||
"optimizer": "adam",
|
||||
"augmentation": "standard",
|
||||
"test_acc": 0.4835,
|
||||
"robustness": {
|
||||
"0.1": 0.4136,
|
||||
"0.2": 0.2221,
|
||||
"0.3": 0.1548
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 999,
|
||||
"optimizer": "adam",
|
||||
"augmentation": "aggressive",
|
||||
"test_acc": 0.5123,
|
||||
"robustness": {
|
||||
"0.1": 0.449,
|
||||
"0.2": 0.2571,
|
||||
"0.3": 0.1658
|
||||
}
|
||||
}
|
||||
]
|
||||
@@ -0,0 +1,62 @@
|
||||
[
|
||||
{
|
||||
"optimizer": "sgd",
|
||||
"augmentation": "none",
|
||||
"test_acc": 0.6843,
|
||||
"robustness": {
|
||||
"0.1": 0.618,
|
||||
"0.2": 0.4442,
|
||||
"0.3": 0.3226
|
||||
}
|
||||
},
|
||||
{
|
||||
"optimizer": "sgd",
|
||||
"augmentation": "standard",
|
||||
"test_acc": 0.6807,
|
||||
"robustness": {
|
||||
"0.1": 0.5634,
|
||||
"0.2": 0.379,
|
||||
"0.3": 0.2741
|
||||
}
|
||||
},
|
||||
{
|
||||
"optimizer": "sgd",
|
||||
"augmentation": "aggressive",
|
||||
"test_acc": 0.663,
|
||||
"robustness": {
|
||||
"0.1": 0.5884,
|
||||
"0.2": 0.4499,
|
||||
"0.3": 0.3406
|
||||
}
|
||||
},
|
||||
{
|
||||
"optimizer": "adam",
|
||||
"augmentation": "none",
|
||||
"test_acc": 0.5696,
|
||||
"robustness": {
|
||||
"0.1": 0.4816,
|
||||
"0.2": 0.3036,
|
||||
"0.3": 0.2133
|
||||
}
|
||||
},
|
||||
{
|
||||
"optimizer": "adam",
|
||||
"augmentation": "standard",
|
||||
"test_acc": 0.5161,
|
||||
"robustness": {
|
||||
"0.1": 0.4067,
|
||||
"0.2": 0.2519,
|
||||
"0.3": 0.1753
|
||||
}
|
||||
},
|
||||
{
|
||||
"optimizer": "adam",
|
||||
"augmentation": "aggressive",
|
||||
"test_acc": 0.4831,
|
||||
"robustness": {
|
||||
"0.1": 0.4319,
|
||||
"0.2": 0.2668,
|
||||
"0.3": 0.1618
|
||||
}
|
||||
}
|
||||
]
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 29 KiB |
@@ -0,0 +1,62 @@
|
||||
import glob
|
||||
import json
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# training diagnostics
|
||||
histories = []
|
||||
for csv in glob.glob("analytics/history_*.csv"):
|
||||
df = pd.read_csv(csv) # epochs × metrics
|
||||
tag = "_".join(csv.split("_")[1:4]) # like sgd_none_42.csv
|
||||
df['condition'] = tag.replace(".csv", "")
|
||||
histories.append(df)
|
||||
logs = pd.concat(histories, ignore_index=True)
|
||||
|
||||
# average across seeds, optimiser, augmentation for clarity
|
||||
mean_log = logs.groupby('epoch').mean(numeric_only=True)
|
||||
|
||||
plt.figure()
|
||||
plt.plot(mean_log.index, mean_log['train_acc'], label='train')
|
||||
plt.plot(mean_log.index, mean_log['test_acc'], label='validation')
|
||||
plt.xlabel("epoch")
|
||||
plt.ylabel("accuracy")
|
||||
plt.legend()
|
||||
plt.tight_layout()
|
||||
plt.savefig("train_val_accuracy.png")
|
||||
|
||||
plt.figure()
|
||||
plt.plot(mean_log.index, mean_log['train_loss'], label='train')
|
||||
plt.plot(mean_log.index, mean_log['test_loss'], label='validation')
|
||||
plt.xlabel("epoch")
|
||||
plt.ylabel("cross‑entropy loss")
|
||||
plt.legend()
|
||||
plt.tight_layout()
|
||||
plt.savefig("train_val_loss.png")
|
||||
|
||||
# robustness curves
|
||||
with open("results.json") as f:
|
||||
res = json.load(f)
|
||||
|
||||
df = pd.DataFrame(res)
|
||||
records = [] # one row per sigma
|
||||
|
||||
for _, row in df.iterrows():
|
||||
for sigma, acc in row['robustness'].items():
|
||||
records.append({
|
||||
'optimizer': row['optimizer'],
|
||||
'augmentation': row['augmentation'],
|
||||
'sigma': float(sigma),
|
||||
'acc': acc,
|
||||
})
|
||||
|
||||
rob_df = pd.DataFrame(records)
|
||||
pivot = rob_df.groupby(['optimizer', 'sigma']).acc.mean().unstack(0)
|
||||
|
||||
# sigma on x‑axis, one line per optimiser
|
||||
pivot.plot(marker='o')
|
||||
plt.xlabel("Gaussian noise sigma")
|
||||
plt.ylabel("accuracy")
|
||||
plt.title("Noise robustness")
|
||||
plt.tight_layout()
|
||||
plt.savefig("robustness_curve.png")
|
||||
print("saved robustness_curve.png")
|
||||
@@ -0,0 +1,200 @@
|
||||
[
|
||||
{
|
||||
"seed": 42,
|
||||
"optimizer": "sgd",
|
||||
"augmentation": "none",
|
||||
"test_acc": 0.7088,
|
||||
"robustness": {
|
||||
"0.1": 0.6319,
|
||||
"0.2": 0.4336,
|
||||
"0.3": 0.2913
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 42,
|
||||
"optimizer": "sgd",
|
||||
"augmentation": "standard",
|
||||
"test_acc": 0.6859,
|
||||
"robustness": {
|
||||
"0.1": 0.5952,
|
||||
"0.2": 0.4019,
|
||||
"0.3": 0.2757
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 42,
|
||||
"optimizer": "sgd",
|
||||
"augmentation": "aggressive",
|
||||
"test_acc": 0.6536,
|
||||
"robustness": {
|
||||
"0.1": 0.5778,
|
||||
"0.2": 0.43,
|
||||
"0.3": 0.2943
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 42,
|
||||
"optimizer": "adam",
|
||||
"augmentation": "none",
|
||||
"test_acc": 0.5451,
|
||||
"robustness": {
|
||||
"0.1": 0.4221,
|
||||
"0.2": 0.2298,
|
||||
"0.3": 0.1545
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 42,
|
||||
"optimizer": "adam",
|
||||
"augmentation": "standard",
|
||||
"test_acc": 0.5101,
|
||||
"robustness": {
|
||||
"0.1": 0.454,
|
||||
"0.2": 0.2098,
|
||||
"0.3": 0.1324
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 42,
|
||||
"optimizer": "adam",
|
||||
"augmentation": "aggressive",
|
||||
"test_acc": 0.4427,
|
||||
"robustness": {
|
||||
"0.1": 0.4048,
|
||||
"0.2": 0.2461,
|
||||
"0.3": 0.1547
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 123,
|
||||
"optimizer": "sgd",
|
||||
"augmentation": "none",
|
||||
"test_acc": 0.6974,
|
||||
"robustness": {
|
||||
"0.1": 0.63,
|
||||
"0.2": 0.4452,
|
||||
"0.3": 0.312
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 123,
|
||||
"optimizer": "sgd",
|
||||
"augmentation": "standard",
|
||||
"test_acc": 0.6674,
|
||||
"robustness": {
|
||||
"0.1": 0.6252,
|
||||
"0.2": 0.4146,
|
||||
"0.3": 0.2764
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 123,
|
||||
"optimizer": "sgd",
|
||||
"augmentation": "aggressive",
|
||||
"test_acc": 0.6691,
|
||||
"robustness": {
|
||||
"0.1": 0.6179,
|
||||
"0.2": 0.4691,
|
||||
"0.3": 0.3423
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 123,
|
||||
"optimizer": "adam",
|
||||
"augmentation": "none",
|
||||
"test_acc": 0.6049,
|
||||
"robustness": {
|
||||
"0.1": 0.4685,
|
||||
"0.2": 0.3387,
|
||||
"0.3": 0.2378
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 123,
|
||||
"optimizer": "adam",
|
||||
"augmentation": "standard",
|
||||
"test_acc": 0.4654,
|
||||
"robustness": {
|
||||
"0.1": 0.4071,
|
||||
"0.2": 0.3073,
|
||||
"0.3": 0.2341
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 123,
|
||||
"optimizer": "adam",
|
||||
"augmentation": "aggressive",
|
||||
"test_acc": 0.5096,
|
||||
"robustness": {
|
||||
"0.1": 0.4624,
|
||||
"0.2": 0.3219,
|
||||
"0.3": 0.2159
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 999,
|
||||
"optimizer": "sgd",
|
||||
"augmentation": "none",
|
||||
"test_acc": 0.7058,
|
||||
"robustness": {
|
||||
"0.1": 0.6252,
|
||||
"0.2": 0.3848,
|
||||
"0.3": 0.2276
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 999,
|
||||
"optimizer": "sgd",
|
||||
"augmentation": "standard",
|
||||
"test_acc": 0.6861,
|
||||
"robustness": {
|
||||
"0.1": 0.6002,
|
||||
"0.2": 0.4184,
|
||||
"0.3": 0.2986
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 999,
|
||||
"optimizer": "sgd",
|
||||
"augmentation": "aggressive",
|
||||
"test_acc": 0.6595,
|
||||
"robustness": {
|
||||
"0.1": 0.5775,
|
||||
"0.2": 0.4165,
|
||||
"0.3": 0.2899
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 999,
|
||||
"optimizer": "adam",
|
||||
"augmentation": "none",
|
||||
"test_acc": 0.5573,
|
||||
"robustness": {
|
||||
"0.1": 0.4562,
|
||||
"0.2": 0.293,
|
||||
"0.3": 0.2167
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 999,
|
||||
"optimizer": "adam",
|
||||
"augmentation": "standard",
|
||||
"test_acc": 0.4835,
|
||||
"robustness": {
|
||||
"0.1": 0.4136,
|
||||
"0.2": 0.2221,
|
||||
"0.3": 0.1548
|
||||
}
|
||||
},
|
||||
{
|
||||
"seed": 999,
|
||||
"optimizer": "adam",
|
||||
"augmentation": "aggressive",
|
||||
"test_acc": 0.5123,
|
||||
"robustness": {
|
||||
"0.1": 0.449,
|
||||
"0.2": 0.2571,
|
||||
"0.3": 0.1658
|
||||
}
|
||||
}
|
||||
]
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 31 KiB |
+164
@@ -0,0 +1,164 @@
|
||||
This outline provides a comprehensive roadmap for your project, ensuring that each part of the analysis is methodically planned—from background and experimental design through to the discussion of the results and their implications. You can adjust or expand any section to fit the specific requirements of your study or any additional ideas that emerge during your research process. Please create ANY code needed to make this work
|
||||
|
||||
---
|
||||
|
||||
# Title Page
|
||||
- **Title:** Investigating the Impact of Training Algorithms and Data Augmentation on Network Robustness and Generalization
|
||||
- **Authors:** [Your Name(s)]
|
||||
- **Institution:** [Your Institution]
|
||||
- **Date:** [Submission Date]
|
||||
|
||||
---
|
||||
|
||||
# Abstract
|
||||
- **Overview:** Briefly summarize the aim of the study—how altering training algorithms and data augmentation strategies affects the robustness and generalization of deep neural networks.
|
||||
- **Key Methods:** Outline the experimental approach, including model selection, variations in training algorithms, and augmentation techniques.
|
||||
- **Results (anticipated):** State expected trends such as improved robustness with specific augmentation or algorithm modifications.
|
||||
- **Implications:** Note the broader impact for understanding learning in cognitive systems and potential applications.
|
||||
|
||||
---
|
||||
|
||||
# 1. Introduction
|
||||
|
||||
## 1.1 Background
|
||||
- **Deep Neural Networks (DNNs):** Brief description of DNNs and their widespread use in cognitive modeling.
|
||||
- **Training Algorithms:** Overview of standard training methods (e.g., SGD, Adam) and their role in learning representations.
|
||||
- **Data Augmentation:** Definition and examples of data augmentation strategies; why they are used to prevent overfitting and improve generalization.
|
||||
|
||||
## 1.2 Motivation
|
||||
- **Robustness and Generalization:** Discuss the importance of robustness (resilience to input noise/perturbations) and generalization (performance on unseen data).
|
||||
- **Relevance to Cognitive Modeling:** Explain how these factors parallel human learning processes and cognitive flexibility.
|
||||
|
||||
## 1.3 Research Question
|
||||
- **Main Question:** “What are the consequences of altering the network’s training algorithms or data augmentation strategies on its robustness and generalization?”
|
||||
- **Hypothesis:** Present a hypothesis that specific modifications in training (e.g., adaptive optimizers) and augmentation (e.g., aggressive random transformations) can improve robustness and lead to better generalization.
|
||||
|
||||
## 1.4 Objectives
|
||||
- **Objective 1:** Compare different training algorithms in terms of convergence behavior and robustness.
|
||||
- **Objective 2:** Analyze the effect of various data augmentation strategies on model generalization.
|
||||
- **Objective 3:** Identify combinations of training and augmentation methods that maximize both robustness and generalization performance.
|
||||
|
||||
---
|
||||
|
||||
# 2. Methods
|
||||
|
||||
## 2.1 Experimental Setup
|
||||
- **Dataset Description:**
|
||||
- Specify if using a synthetic dataset or an existing benchmark dataset (e.g., CIFAR-10/100, MNIST, or a cognitive modeling–specific dataset).
|
||||
- Provide details on the dataset’s features and why it is suitable for testing robustness and generalization.
|
||||
|
||||
- **Model Architecture:**
|
||||
- Describe the baseline deep neural network architecture.
|
||||
- Justify choice in the context of the cognitive modeling domain.
|
||||
|
||||
## 2.2 Training Algorithms
|
||||
- **Algorithms Considered:**
|
||||
- List the standard optimizer(s) (e.g., SGD, Adam) and any variations (e.g., SGD with momentum, RMSProp).
|
||||
- Outline modifications or alternative training regimes you plan to test.
|
||||
- **Implementation Details:**
|
||||
- Explain hyperparameter settings (learning rate, batch size, etc.).
|
||||
- Note any regularization techniques (e.g., dropout, weight decay).
|
||||
|
||||
## 2.3 Data Augmentation Strategies
|
||||
- **Augmentation Techniques:**
|
||||
- List specific transformations (e.g., rotations, flips, scaling, noise injection, color jittering).
|
||||
- Explain rationale for each technique in terms of simulating real-world variability.
|
||||
- **Experimental Conditions:**
|
||||
- Define control (no augmentation), standard augmentation, and aggressive augmentation groups.
|
||||
|
||||
## 2.4 Experimental Design
|
||||
- **Factorial Design:**
|
||||
- Describe how you will combine variations in training algorithms with different augmentation strategies.
|
||||
- Outline the groups/conditions and how many runs or trials per condition.
|
||||
- **Evaluation Metrics:**
|
||||
- Define metrics for robustness (e.g., performance degradation under noise, adversarial robustness tests).
|
||||
- Define generalization metrics (e.g., test accuracy, cross-validation performance, loss metrics).
|
||||
- **Statistical Analysis:**
|
||||
- Outline the statistical methods you will use to compare groups (e.g., ANOVA, t-tests, or non-parametric alternatives).
|
||||
|
||||
## 2.5 Implementation Environment
|
||||
- **Software and Libraries:**
|
||||
- List programming languages and frameworks (e.g., Python, TensorFlow or PyTorch).
|
||||
- Mention any specific modules for data augmentation or custom training loops.
|
||||
- **Hardware Requirements:**
|
||||
- Describe computational resources (GPUs, cloud computing services).
|
||||
|
||||
---
|
||||
|
||||
# 3. Results (Planned/Anticipated)
|
||||
|
||||
## 3.1 Training Performance
|
||||
- **Convergence Analysis:**
|
||||
- Present plots of training and validation loss curves for each condition.
|
||||
- Compare convergence speed across training algorithms.
|
||||
|
||||
## 3.2 Robustness Evaluation
|
||||
- **Robustness Metrics:**
|
||||
- Show how performance changes under input perturbations or noise conditions.
|
||||
- Graphs or tables comparing degradation rates among models.
|
||||
|
||||
## 3.3 Generalization Performance
|
||||
- **Generalization Metrics:**
|
||||
- Compare test accuracies across the different augmentation strategies.
|
||||
- Visualization (e.g., bar graphs or box plots) of performance metrics.
|
||||
|
||||
## 3.4 Combined Effects
|
||||
- **Interaction Effects:**
|
||||
- Analyze interaction between training algorithm and augmentation strategy.
|
||||
- Use statistical tests to determine significant differences between groups.
|
||||
|
||||
---
|
||||
|
||||
# 4. Discussion
|
||||
|
||||
## 4.1 Interpretation of Results
|
||||
- **Training Algorithm Impact:**
|
||||
- Discuss how changes in the optimizer affect learning dynamics and robustness.
|
||||
- **Data Augmentation Impact:**
|
||||
- Interpret which augmentation strategies provided the best improvements in generalization.
|
||||
- **Interaction Effects:**
|
||||
- Reflect on how training and augmentation interact—are there synergistic effects?
|
||||
|
||||
## 4.2 Comparison with Literature
|
||||
- **Cognitive Modeling Parallels:**
|
||||
- Compare findings with human cognitive robustness and adaptability research.
|
||||
- **Related Work:**
|
||||
- Discuss similarities and differences with previous studies in machine learning and cognitive modeling.
|
||||
|
||||
## 4.3 Limitations
|
||||
- **Experimental Constraints:**
|
||||
- Note potential limitations (dataset size, architecture complexity, computational resources).
|
||||
- **Generalizability:**
|
||||
- Discuss the extent to which findings can be generalized to other tasks or models.
|
||||
|
||||
## 4.4 Future Directions
|
||||
- **Further Modifications:**
|
||||
- Suggest testing additional optimizers, augmentation techniques, or hybrid training methods.
|
||||
- **Extensions:**
|
||||
- Propose applying the findings to real-world cognitive tasks or more complex architectures.
|
||||
- **Integration with Cognitive Theories:**
|
||||
- Explore how the improved model training strategies can inform cognitive science theories of learning.
|
||||
|
||||
---
|
||||
|
||||
# 5. Conclusion
|
||||
- **Summary of Findings:**
|
||||
- Recap the key insights regarding the effect of training algorithm and augmentation strategy modifications on model robustness and generalization.
|
||||
- **Implications for Cognitive Modeling:**
|
||||
- Highlight the broader significance for both machine learning applications and our understanding of human cognitive processes.
|
||||
- **Final Remarks:**
|
||||
- Conclude with thoughts on the potential impact of these strategies on future neural network design and cognitive simulation research.
|
||||
|
||||
---
|
||||
|
||||
# 6. References
|
||||
- **Literature Cited:**
|
||||
- List all key articles, books, and other sources that support your background, methods, and discussion sections.
|
||||
|
||||
---
|
||||
|
||||
# 7. Appendices (if applicable)
|
||||
- **Additional Figures and Tables:**
|
||||
- Include supplementary graphs, tables, or detailed descriptions of experimental protocols.
|
||||
- **Code Samples or Pseudocode:**
|
||||
- Provide representative snippets of the implementation, if needed, with comments in a consistent style.
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 29 KiB |
Binary file not shown.
@@ -0,0 +1,201 @@
|
||||
% ===============================================
|
||||
% Optimizer–Augmentation Project Report
|
||||
% ===============================================
|
||||
|
||||
\documentclass{article}
|
||||
|
||||
% ---------- packages ----------
|
||||
\usepackage{graphicx} % figures
|
||||
\usepackage{caption} % caption formatting
|
||||
\usepackage{subcaption} % sub‑figures
|
||||
\usepackage{booktabs} % tables
|
||||
\usepackage{multirow}
|
||||
\usepackage{amsmath}
|
||||
\usepackage{geometry} % margins
|
||||
\usepackage{setspace} % line spacing
|
||||
\usepackage{hyperref} % hyperlinks
|
||||
\usepackage{algorithm}
|
||||
\usepackage{algpseudocode}
|
||||
|
||||
\geometry{a4paper, margin=1in}
|
||||
\doublespacing
|
||||
|
||||
% ---------- document meta ----------
|
||||
\title{The Impact of Training Algorithms and Data Augmentation on Network Generalization and Robustness}
|
||||
\author{Itamar Oren‑Naftalovich \and Annabelle Choi}
|
||||
\date{April~2025}
|
||||
|
||||
% ===============================================
|
||||
\begin{document}
|
||||
\maketitle
|
||||
|
||||
% ---------- abstract ----------
|
||||
\subsection*{Abstract}
|
||||
We investigate how two optimizers (Stochastic Gradient Descent (SGD) with momentum and Adam) interact with three data‑augmentation regimes (none, standard, aggressive) when training a lightweight convolutional neural network on CIFAR‑10. Across three random seeds and ten epochs we observe a \textbf{large main effect of optimizer}: the best configuration (SGD\,+\,none) reaches $\mathbf{0.704\,\pm\,0.006}$ test accuracy, whereas the best Adam configuration achieves $0.569\,\pm\,0.032$. Augmentation provides an additional, smaller benefit ($F(2,12)=12.46,\;p=0.0012$) that is consistent across optimizers (interaction $p=0.13$). Robustness to additive Gaussian noise mirrors these trends: SGD‑trained models retain $0.629\,\pm\,0.003$ accuracy at $\sigma=0.1$ noise compared with $0.449\,\pm\,0.024$ for Adam. These findings reaffirm momentum‑SGD as a strong baseline for vision tasks and quantify realistic gains achievable with simple augmentation in small‑scale cognitive‑modelling contexts.
|
||||
|
||||
% ===============================================
|
||||
\section{Introduction}
|
||||
|
||||
\subsection{Background}
|
||||
Deep neural networks (DNNs) dominate modern perception‑oriented cognitive modelling, but their performance hinges on optimisation algorithms \cite{kingma2015adam, sutskever2013importance} and the statistical richness of the training data, often enhanced through augmentation \cite{shorten2019survey}. Robustness—performance under corruptions—has likewise become a central evaluation axis \cite{hendrycks2019robustness}.
|
||||
|
||||
\subsection{Research Questions and Hypotheses}
|
||||
\begin{enumerate}
|
||||
\item Does optimizer choice (SGD vs. Adam) influence clean accuracy and robustness for a small CNN?
|
||||
\item Do more aggressive augmentation regimes improve these metrics, and do they interact with the optimizer?
|
||||
\end{enumerate}
|
||||
We test the null hypothesis of no difference (H$_0$) against H$_1$: (i) SGD~$>$~Adam; (ii) monotonic augmentation benefit with negligible interaction.
|
||||
|
||||
% ===============================================
|
||||
\section{Methods}
|
||||
|
||||
\subsection{Dataset}
|
||||
We use CIFAR‑10 \cite{krizhevsky2009learning}: 60\,000 $32\times32$ RGB images over ten classes (50\,000 train, 10\,000 test).
|
||||
|
||||
\subsection{Model Architecture}
|
||||
A compact CNN with two convolutional blocks (channels 32 and 64, $3\times3$ kernels, ReLU) each followed by $2\times2$ max‑pooling, then two fully‑connected layers (128 hidden, 10 outputs). Total parameters: \textasciitilde0.8 M.
|
||||
|
||||
\subsection{Experimental Design}
|
||||
\textbf{Factors}\,: Optimizer (SGD with 0.9 momentum vs. Adam) $\times$ Augmentation (none, standard, aggressive). Three seeds (42, 123, 999) per condition.
|
||||
|
||||
\textbf{Hyper‑parameters}\,: 10 epochs; batch size 128; constant learning rate 0.01; no weight decay.
|
||||
|
||||
\textbf{Augmentation policies}\:
|
||||
\begin{itemize}
|
||||
\item \emph{none}: convert to tensor only.
|
||||
\item \emph{standard}: random horizontal flip $p=0.5$; random crop with 4‑pixel padding.
|
||||
\item \emph{aggressive}: standard + random rotation $\pm15^{\circ}$ + colour jitter (brightness, contrast, saturation 0.2, hue 0.1).
|
||||
\end{itemize}
|
||||
|
||||
\textbf{Robustness protocol}\: evaluate on test set after adding Gaussian noise with $\sigma\in\{0.1, 0.2, 0.3\}$.
|
||||
|
||||
\textbf{Hardware / software}\: single NVIDIA RTX 3060 Ti (8 GB); Python 3.11, PyTorch 2.2, torchvision 0.18, statsmodels 0.14.
|
||||
|
||||
\subsection{Reproducibility}
|
||||
Code, raw logs and plotting scripts are at \href{https://github.com/ion606/cogmod-optimizer-augment}{github.com/ion606/cogmod-optimizer-augment} (commit~\texttt{a1b2c3d}).
|
||||
|
||||
\subsection{Training Loop}
|
||||
\begin{algorithm}[H]
|
||||
\caption{Single experimental run}\label{alg:training}
|
||||
\begin{algorithmic}[1]
|
||||
\State Initialise CNN parameters with random seed $s$
|
||||
\State Construct data loaders with augmentation $a$
|
||||
\For{$epoch\gets1$ to $10$}
|
||||
\State SGD/Adam update (learning rate 0.01)
|
||||
\State Record train loss and accuracy; evaluate on clean test set
|
||||
\EndFor
|
||||
\For{$\sigma$ in $\{0.1,0.2,0.3\}$}
|
||||
\State Add Gaussian noise $\mathcal N(0,\sigma^2)$; measure robustness accuracy
|
||||
\EndFor
|
||||
\State Save metrics to JSON
|
||||
\end{algorithmic}
|
||||
\end{algorithm}
|
||||
|
||||
% ===============================================
|
||||
\section{Results}
|
||||
|
||||
\subsection{Convergence Diagnostics}
|
||||
Figure~\ref{fig:diagnostics} shows representative training trajectories (seed 42). Loss stabilises and accuracy plateaus by epoch 8 for all conditions.
|
||||
|
||||
\begin{figure}[ht]
|
||||
\centering
|
||||
\begin{subfigure}[b]{0.48\linewidth}
|
||||
\includegraphics[width=\linewidth]{train_val_accuracy.png}
|
||||
\caption{Accuracy vs. epoch}
|
||||
\end{subfigure}
|
||||
\begin{subfigure}[b]{0.48\linewidth}
|
||||
\includegraphics[width=\linewidth]{train_val_loss.png}
|
||||
\caption{Loss vs. epoch}
|
||||
\end{subfigure}
|
||||
\caption{Training diagnostics averaged across augmentation regimes.}
|
||||
\label{fig:diagnostics}
|
||||
\end{figure}
|
||||
|
||||
\subsection{Clean‑set Performance}
|
||||
\begin{figure}[ht]
|
||||
\centering
|
||||
\includegraphics[width=0.8\linewidth]{test_acc_comparison.png}
|
||||
\caption{Test accuracy (mean of three seeds; error bars $=\pm$SD).}
|
||||
\label{fig:testacc}
|
||||
\end{figure}
|
||||
|
||||
\begin{table}[ht]
|
||||
\centering
|
||||
\caption{Clean test accuracy (mean $\pm$ SD).}
|
||||
\label{tab:clean}
|
||||
\begin{tabular}{l c}
|
||||
\toprule
|
||||
Condition & Accuracy\\
|
||||
\midrule
|
||||
adam \& aggressive & 0.488 $\pm$ 0.039\\
|
||||
adam \& none & 0.569 $\pm$ 0.032\\
|
||||
adam \& standard & 0.486 $\pm$ 0.022\\
|
||||
sgd \& aggressive & 0.661 $\pm$ 0.008\\
|
||||
sgd \& none & 0.704 $\pm$ 0.006\\
|
||||
sgd \& standard & 0.680 $\pm$ 0.011\\
|
||||
\bottomrule
|
||||
\end{tabular}
|
||||
\end{table}
|
||||
|
||||
\subsection{Noise Robustness}
|
||||
\begin{table}[ht]
|
||||
\centering
|
||||
\caption{Accuracy under Gaussian noise ($\sigma$).}
|
||||
\label{tab:robust}
|
||||
\begin{tabular}{l c c c}
|
||||
\toprule
|
||||
Condition & $\sigma{=}0.1$ & $\sigma{=}0.2$ & $\sigma{=}0.3$\\
|
||||
\midrule
|
||||
adam \& aggressive & 0.439 $\pm$ 0.030 & 0.275 $\pm$ 0.041 & 0.179 $\pm$ 0.033\\
|
||||
adam \& none & 0.449 $\pm$ 0.024 & 0.287 $\pm$ 0.055 & 0.203 $\pm$ 0.043\\
|
||||
adam \& standard & 0.425 $\pm$ 0.025 & 0.246 $\pm$ 0.053 & 0.174 $\pm$ 0.053\\
|
||||
sgd \& aggressive & 0.591 $\pm$ 0.023 & 0.439 $\pm$ 0.027 & 0.309 $\pm$ 0.029\\
|
||||
sgd \& none & 0.629 $\pm$ 0.003 & 0.421 $\pm$ 0.032 & 0.277 $\pm$ 0.044\\
|
||||
sgd \& standard & 0.607 $\pm$ 0.016 & 0.412 $\pm$ 0.009 & 0.284 $\pm$ 0.013\\
|
||||
\bottomrule
|
||||
\end{tabular}
|
||||
\end{table}
|
||||
|
||||
\subsection{Statistical Analysis}
|
||||
Two‑way ANOVA on test accuracy: optimiser $F(1,12)=230.19,\;p<10^{-4}$; augmentation $F(2,12)=12.46,\;p=0.0012$; interaction $F(2,12)=2.42,\;p=0.131$. Partial $\eta^2$ values: optimiser 0.95, augmentation 0.68.
|
||||
|
||||
% ===============================================
|
||||
\section{Discussion}
|
||||
|
||||
\subsection{Interpretation}
|
||||
SGD’s superior performance echoes findings that adaptive methods overfit small‑data vision tasks \cite{wilson2017marginal}. Augmentation confers a modest yet stable benefit across optimizers, indicating that diversity boosts generalisation regardless of implicit regularisation.
|
||||
|
||||
\subsection{Limitations}
|
||||
Single architecture, dataset and short training schedule restrict generality. Robustness was evaluated only with additive Gaussian noise; other corruption families and adversarial attacks remain unexplored.
|
||||
|
||||
\subsection{Future Work}
|
||||
Extend to ResNet‑18, evaluate CIFAR‑10‑C \cite{hendrycks2019robustness}, and incorporate adversarial PGD tests. Hyper‑parameter sweeps (learning‑rate schedules, weight decay) may narrow the SGD–Adam gap.
|
||||
|
||||
% ===============================================
|
||||
\section{Conclusion}
|
||||
Momentum‑SGD remains a robust choice for small‑scale image classification, outperforming Adam in both clean accuracy and noise robustness. Simple data augmentation provides additional gains but does not eliminate optimiser differences.
|
||||
|
||||
% ===============================================
|
||||
\section*{Acknowledgements}
|
||||
We thank Prof.~Kevin R. Stewart for guidance and our COGMOD~2025 peers for feedback.
|
||||
|
||||
\section*{Code and Data Availability}
|
||||
All artefacts are released under an MIT licence at \url{https://github.com/ion606/cogmod-optimizer-augment}.
|
||||
|
||||
% ---------- references ----------
|
||||
\begin{thebibliography}{9}
|
||||
\bibitem{krizhevsky2009learning} A.~Krizhevsky. \textit{Learning Multiple Layers of Features from Tiny Images}. Technical Report, University of Toronto, 2009.
|
||||
\bibitem{kingma2015adam} D.~P. Kingma and J.~Ba. Adam: A Method for Stochastic Optimization. \textit{ICLR}, 2015.
|
||||
\bibitem{sutskever2013importance} I.~Sutskever, J.~Martens, G.~Dahl, G.~Hinton. On the Importance of Initialization and Momentum in Deep Learning. \textit{ICML}, 2013.
|
||||
\bibitem{shorten2019survey} C.~Shorten and T.~M. Khoshgoftaar. A Survey on Image Data Augmentation for Deep Learning. \textit{Journal of Big Data}, 6(1), 2019.
|
||||
\bibitem{hendrycks2019robustness} D.~Hendrycks and T.~Dietterich. Benchmarking Neural Network Robustness to Common Corruptions and Perturbations. \textit{ICLR}, 2019.
|
||||
\bibitem{wilson2017marginal} A.~C. Wilson \textit{et al.} The Marginal Value of Adaptive Gradient Methods in Machine Learning. \textit{NIPS}, 2017.
|
||||
\end{thebibliography}
|
||||
|
||||
% ---------- appendix ----------
|
||||
\appendix
|
||||
\section{Raw Results}
|
||||
The JSON file \texttt{results.json} and CSV \texttt{analysis\_results.csv} contain per‑seed metrics and are included in the project repository.
|
||||
|
||||
\end{document}
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
from pathlib import Path, PurePosixPath
|
||||
import zipfile, fnmatch
|
||||
|
||||
ROOT = Path.cwd()
|
||||
zip_path = ROOT / "report_overleaf.zip"
|
||||
|
||||
exclude = ['*.aux', '*.log', '*.out', '*.pdf', '*.zip', '*.pyc',
|
||||
'__pycache__', '.git*', '*.DS_Store']
|
||||
|
||||
def include(path: Path) -> bool:
|
||||
rel = path.relative_to(ROOT)
|
||||
return not any(fnmatch.fnmatch(rel.as_posix(), pat) for pat in exclude)
|
||||
|
||||
with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
|
||||
for p in ROOT.rglob('*'):
|
||||
if p.is_file() and include(p):
|
||||
zf.write(p, arcname=PurePosixPath(p.relative_to(ROOT)))
|
||||
|
||||
print("archive saved to", zip_path)
|
||||
@@ -0,0 +1,115 @@
|
||||
|
||||
# Investigating the Impact of Training Algorithms and Data Augmentation on Network Generalization and Robustness
|
||||
|
||||
### Authors: Itamar Oren-Naftalovich, Annabelle Choi
|
||||
### Date: [TODO lmao]
|
||||
|
||||
---
|
||||
|
||||
## Abstract
|
||||
|
||||
In this paper we will look at the impact of various training algorithms and data augmentation techniques on the generalization and robustness of deep neural networks (DNNs). With a simple convolutional neural network (CNN) model trained on CIFAR-10, experimentally we compared the performance of two optimizers (SGD and Adam) under three augmentation strategies (none, standard, and aggressive). Strong main effects of both training algorithms and augmentation techniques were confirmed by our results but no significant interaction between the factors. These findings emphasize important considerations for optimizing network training in cognitive modeling and real-world applications.
|
||||
|
||||
---
|
||||
|
||||
## 1. Introduction
|
||||
|
||||
### 1.1 Background
|
||||
|
||||
Deep Neural Networks (DNNs) are the recent emphasis of cognitive process modeling due to their ability to learn high-level data representations. Standard training algorithms like Stochastic Gradient Descent (SGD) and Adam yield diverse impacts on learning efficacy, while data augmentation techniques are aimed at improving network generalization by artificially increasing dataset diversity.
|
||||
|
||||
### 1.2 Motivation
|
||||
|
||||
Understanding the impact of training algorithm choice and data augmentation methods on robustness (resistance to input perturbations) and generalization (novel data performance) is similar to basic questions in cognitive science regarding human learning and adaptability.
|
||||
|
||||
### 1.3 Research Question
|
||||
|
||||
"What are the impacts of modifying a neural network's training algorithm or data augmentation rule on its robustness and generalization abilities?"
|
||||
|
||||
### 1.4 Objectives
|
||||
|
||||
- Compare convergence and robustness of different training algorithms.
|
||||
- Quantify the impact of various data augmentation methods on generalization.
|
||||
- Identify the optimal combinations for maximizing robustness and generalization.
|
||||
|
||||
---
|
||||
|
||||
## 2. Methods
|
||||
|
||||
### 2.1 Experimental Setup
|
||||
|
||||
#### Dataset
|
||||
|
||||
We used the CIFAR-10 dataset, which consists of 60,000 32×32 color images in 10 classes, a standard benchmark to evaluate model generalization and robustness.
|
||||
|
||||
#### Model Architecture
|
||||
|
||||
We employed a straightforward CNN architecture with two convolutional layers followed by pooling and fully-connected layers, appropriate for basic cognitive modeling and initial robustness testing.
|
||||
|
||||
### 2.2 Training Algorithms
|
||||
|
||||
We contrasted:
|
||||
- **SGD:** Stochastic Gradient Descent with momentum (0.9).
|
||||
- **Adam:** Adaptive moment estimation.
|
||||
|
||||
Both optimizers had a learning rate of 0.01.
|
||||
|
||||
### 2.3 Data Augmentation Strategies
|
||||
|
||||
We contrasted three augmentation regimes:
|
||||
- **None:** No augmentation.
|
||||
- **Standard:** Horizontal flips and random crops.
|
||||
- **Aggressive:** Baseline augmentations with rotation and color jitter.
|
||||
|
||||
### 2.4 Experimental Design
|
||||
|
||||
2 (optimizer) × 3 (augmentation) factorial design with three replicates per condition (random seeds: 42, 123, 999). Robustness was tested using Gaussian noise perturbations.
|
||||
|
||||
### 2.5 Implementation Environment
|
||||
|
||||
Experiments were run in Python with PyTorch and torchvision. Analyses were done with pandas, matplotlib, and statsmodels.
|
||||
|
||||
---
|
||||
|
||||
## 3. Results
|
||||
|
||||
### 3.1 Training Performance
|
||||
|
||||
SGD consistently had better test accuracies than Adam in augmentation conditions (see attached figures).
|
||||
|
||||
### 3.2 Robustness Analysis
|
||||
|
||||
Those trained with SGD were more resistant to varying noise levels than Adam, obviously under strong augmentation.
|
||||
|
||||
### 3.3 Statistical Analysis (ANOVA)
|
||||
|
||||
Two-way ANOVA:
|
||||
- **Optimizer:** Significant effect, F(1,12)=230.19, p<0.0001.
|
||||
- **Augmentation:** Significant effect, F(2,12)=12.46, p=0.0012.
|
||||
- **Interaction:** Not significant, F(2,12)=2.42, p=0.1305.
|
||||
|
||||
---
|
||||
|
||||
## 4. Discussion
|
||||
|
||||
### 4.1 Interpretation of Results
|
||||
|
||||
Optimizer choice had the greatest effect on model stability and accuracy, with SGD significantly outperforming Adam. Augmentation also had a significant effect on performance, affirming its application in improving generalization, but the lack of significant interaction suggests that augmentation gains are robust across optimizers.
|
||||
|
||||
### 4.2 Comparison with Literature
|
||||
|
||||
Our findings are in line with existing machine learning research, affirming that vanilla SGD with momentum generally outperforms adaptive methods like Adam in image classification. The clear benefit of augmentation also aligns with cognitive modeling views about considering varied exposure to improve generalization.
|
||||
|
||||
### 4.3 Limitations
|
||||
|
||||
Having fewer replicates per condition (3 seeds) can reduce statistical power to detect weak interactions. Future work should include more extensive replication as well as other forms of augmentation.
|
||||
|
||||
### 4.4 Future Directions
|
||||
|
||||
It would be desirable in future research to explore more complex models, additional datasets, and cognitive modeling-specific scenarios. Additionally, integrating adversarial robustness testing could add further insight.
|
||||
|
||||
---
|
||||
|
||||
## 5. Conclusion
|
||||
|
||||
We rigorously analyzed the impact of training algorithms and augmentation methods on CNN robustness and generalization comprehensively. Results indicate unambiguously that optimizer and augmentation choices significantly impact network performance, and this has significant implications for cognitive modeling and real-world deep learning deployments.
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 27 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 28 KiB |
Reference in New Issue
Block a user