Optimizing a custom function with torch's automatic differentiation

Preface

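The goal is to use torch's built-in automatic differentiation (autograd) to update parameters iteratively, so that we do not have to write the gradient computation ourselves.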

1. Function to optimize

$$y=10\times(x_1+x_2-5)^2+(x_1-x_2)^2$$
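Autograd replaces the hand derivation of the gradient. Differentiating by hand gives dy/dx1 = 20(x1 + x2 - 5) + 2(x1 - x2) and dy/dx2 = 20(x1 + x2 - 5) - 2(x1 - x2). The sketch below checks that `backward()` reproduces these values at an arbitrary test point (1, 2), chosen only for illustration.

```python
import torch

def f(x1, x2):
    # y = 10*(x1+x2-5)^2 + (x1-x2)^2
    return 10 * (x1 + x2 - 5) ** 2 + (x1 - x2) ** 2

# Arbitrary test point (illustrative assumption, not from the original post)
x1 = torch.tensor(1.0, requires_grad=True)
x2 = torch.tensor(2.0, requires_grad=True)

y = f(x1, x2)
y.backward()  # autograd fills x1.grad and x2.grad

# Hand-derived partial derivatives at the same point
g1 = 20 * (1.0 + 2.0 - 5) + 2 * (1.0 - 2.0)
g2 = 20 * (1.0 + 2.0 - 5) - 2 * (1.0 - 2.0)

print(x1.grad.item(), x2.grad.item())  # -42.0 -38.0
print(g1, g2)                          # -42.0 -38.0
```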

1.1 Explanation

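Here we treat [10, 5] as the inputs, the whole function as the model, and [x1, x2] as the parameters to be optimized iteratively. We want the parameters that make y = 0. Since y is a sum of two non-negative squares, y = 0 requires both x1 + x2 = 5 and x1 = x2, so we expect the result to be $x_1 = x_2 = \frac{5}{2}$.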
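The expected answer can be checked with autograd before any training: at x1 = x2 = 5/2 both the function value and its gradient should vanish. A minimal sketch; the names `params`, `a` and `b` are illustrative and only mirror the roles of the parameters and the inputs [10, 5].

```python
import torch

# Candidate solution from the reasoning above: x1 = x2 = 5/2
params = torch.tensor([2.5, 2.5], requires_grad=True)
a, b = 10.0, 5.0  # the "inputs" [10, 5]

y = a * (params[0] + params[1] - b) ** 2 + (params[0] - params[1]) ** 2
y.backward()  # gradient of y with respect to [x1, x2]

print(y.item())     # 0.0
print(params.grad)  # both entries are zero
```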

2. Code

import torch
from torch import nn

# y = 10*(x1+x2-5)^2 + (x1-x2)^2
# The constants [10, 5] are passed in as `inputs`; [x1, x2] are the trainable parameters.
class Func(nn.Module):
    def __init__(self, size=2):
        super().__init__()
        # nn.Parameter registers the tensor with the module and sets requires_grad=True
        self.params = nn.Parameter(torch.rand(size, 1))

    def forward(self, inputs):
        x1, x2 = self.params[0], self.params[1]
        y = inputs[0] * torch.pow(x1 + x2 - inputs[1], 2) + torch.pow(x1 - x2, 2)
        return y

class cusLoss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, y_pred, y_true):
        # Absolute error; since y >= 0, minimizing |0 - y| is the same as minimizing y
        return torch.abs(y_true - y_pred)

model = Func(size=2)

optimizer = torch.optim.AdamW(model.parameters(), lr=1, weight_decay=1e-5)
loss_func = cusLoss()
x = torch.tensor([10.0, 5.0])  # float inputs: [coefficient, required value of x1+x2]
target = torch.zeros(1)        # we want y = 0

for i in range(400):
    y_pred = model(x)
    loss = loss_func(y_pred, target)

    optimizer.zero_grad()  # clear gradients left over from the previous step
    loss.backward()        # autograd computes d(loss)/d(params)
    optimizer.step()       # AdamW updates params in place

    print("loss: ", loss.item())

for item in model.parameters():
    print(item)
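For comparison, `optimizer.step()` can be replaced by a hand-written update. The sketch below swaps AdamW for plain gradient descent (a different optimizer, used here only to expose the mechanism) and updates the parameters inside `torch.no_grad()`. The seed, learning rate 0.02 and step count are assumptions; plain gradient descent on this function does not converge for learning rates of 0.05 or larger, so the lr=1 used with AdamW would not work here.

```python
import torch

torch.manual_seed(0)  # reproducible random start (assumption)
params = torch.rand(2, requires_grad=True)
inputs = torch.tensor([10.0, 5.0])
lr = 0.02  # assumed step size; must stay below 2/40 = 0.05 for plain GD on this function

for step in range(400):
    y = (inputs[0] * (params[0] + params[1] - inputs[1]) ** 2
         + (params[0] - params[1]) ** 2)
    params.grad = None  # clear the previous gradient (what zero_grad does)
    y.backward()        # autograd computes dy/dparams
    with torch.no_grad():
        params -= lr * params.grad  # plain gradient-descent update, not tracked by autograd

print(params)  # both entries close to 2.5
```

Since y is already non-negative, the sketch minimizes y directly instead of wrapping it in an absolute-value loss.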

Reducing lr from 10 to 1 seemed to give noticeably better loss behavior.
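One way to build intuition for the learning-rate sensitivity: y is quadratic, so its Hessian is constant. The sketch below computes it with `torch.autograd.functional.hessian`; its eigenvalues are 4 and 40, so plain gradient descent would need lr < 2/40 = 0.05. AdamW normalizes each step by running gradient statistics, so this bound does not transfer directly to the lr=10 vs lr=1 observation; roughly speaking, Adam-type steps have a per-coordinate size on the order of lr, and steps of about 10 units overshoot a solution that lies only a few units from the random start.

```python
import torch
from torch.autograd.functional import hessian

def f(p):
    # y = 10*(x1+x2-5)^2 + (x1-x2)^2, with p = [x1, x2]
    return 10 * (p[0] + p[1] - 5) ** 2 + (p[0] - p[1]) ** 2

p0 = torch.tensor([0.5, 0.5])    # any point works because the Hessian is constant
H = hessian(f, p0)               # [[22, 18], [18, 22]]
eigs = torch.linalg.eigvalsh(H)  # ascending eigenvalues, about 4 and 40

max_lr_gd = 2.0 / eigs.max().item()  # stability limit for plain gradient descent
print(H)
print(eigs)
print(max_lr_gd)  # about 0.05
```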

3. Results


loss:  156.83274841308594
loss:  38.49453353881836
loss:  0.23079237341880798
loss:  18.87002944946289
loss:  47.266353607177734
loss:  53.96223449707031
loss:  40.30459976196289
loss:  19.8029842376709
loss:  4.426110744476318
loss:  0.13327637314796448
loss:  5.848392009735107
loss:  15.253518104553223
loss:  21.365623474121094
loss:  20.824350357055664
loss:  14.78664779663086
loss:  7.028133392333984
loss:  1.4270392656326294
loss:  0.08896948397159576
loss:  2.5925991535186768
loss:  6.561254978179932
loss:  9.249794960021973
loss:  9.138370513916016
loss:  6.541881561279297
loss:  3.0825929641723633
loss:  0.5860095620155334
loss:  0.07377026975154877
loss:  1.3241856098175049
loss:  3.1676881313323975
loss:  4.294349193572998
loss:  4.038957118988037
loss:  2.664410352706909
loss:  1.0545320510864258
loss:  0.0981285348534584
loss:  0.16388599574565887
loss:  0.9541118741035461
loss:  1.7817351818084717
loss:  2.053565502166748
loss:  1.6271475553512573
loss:  0.8293525576591492
loss:  0.17544522881507874
loss:  0.011017587967216969
loss:  0.313961386680603
loss:  0.7632603049278259
loss:  0.9963911771774292
loss:  0.8582224249839783
loss:  0.4731455445289612
loss:  0.116444431245327
loss:  0.0019159330986440182
loss:  0.14314040541648865
loss:  0.3744982182979584
loss:  0.49575677514076233
loss:  0.4202014207839966
loss:  0.2189445197582245
loss:  0.044234275817871094
loss:  0.005240241996943951
loss:  0.09423969686031342
loss:  0.21066127717494965
loss:  0.25115442276000977
loss:  0.18814465403556824
loss:  0.07860895991325378
loss:  0.0066448356956243515
loss:  0.01395349856466055
loss:  0.07405033707618713
loss:  0.1240086629986763
loss:  0.12016353011131287
loss:  0.0697774738073349
loss:  0.01685400679707527
loss:  0.0004727207124233246
loss:  0.02368409000337124
loss:  0.05685136467218399
loss:  0.06748046725988388
loss:  0.047708846628665924
loss:  0.01680355705320835
loss:  0.0003963433555327356
loss:  0.007644111756235361
loss:  0.02610907331109047
loss:  0.0360269770026207
loss:  0.028617529198527336
loss:  0.011885224841535091
loss:  0.0007842279155738652
loss:  0.0028550983406603336
loss:  0.012763937003910542
loss:  0.019207362085580826
loss:  0.016054006293416023
loss:  0.006997825112193823
loss:  0.0005394043400883675
loss:  0.0014044318813830614
loss:  0.006836464628577232
loss:  0.010403113439679146
loss:  0.008586933836340904
loss:  0.003567643463611603
loss:  0.00020178437989670783
loss:  0.0009794053621590137
loss:  0.004060074221342802
loss:  0.005762426648288965
loss:  0.004401565529406071
loss:  0.001568805892020464
loss:  2.635764940350782e-05
loss:  0.0008179567521438003
loss:  0.002542160451412201
loss:  0.0031480903271585703
loss:  0.0020700229797512293
loss:  0.000531529716681689
loss:  1.5086681742104702e-05
loss:  0.0007311901426874101
loss:  0.0016244511352851987
loss:  0.0016573555767536163
loss:  0.0008591370424255729
loss:  0.00011040566459996626
loss:  9.262182720704004e-05
loss:  0.0006213163724169135
loss:  0.0009869090281426907
loss:  0.0007818497833795846
loss:  0.0002726784732658416
loss:  2.5415104119019816e-06
loss:  0.00017136213136836886
loss:  0.00048359768697991967
loss:  0.0005475623183883727
loss:  0.00030850598705001175
loss:  4.818948218598962e-05
loss:  2.336039688088931e-05
loss:  0.00019348246860317886
loss:  0.0003168497933074832
loss:  0.00024830293841660023
loss:  8.038982196012512e-05
loss:  2.252596686957986e-08
loss:  6.616349855903536e-05
loss:  0.00016678131942171603
loss:  0.00017271166143473238
loss:  8.399917714996263e-05
loss:  7.142824415495852e-06
loss:  1.6916283129830845e-05
loss:  7.81318376539275e-05
loss:  0.00010478033073013648
loss:  6.641951040364802e-05
loss:  1.2883243471151218e-05
loss:  3.119921984762186e-06
loss:  3.588973049772903e-05
loss:  6.165451486594975e-05
loss:  4.818914021598175e-05
loss:  1.456955124012893e-05
loss:  7.381692057606415e-08
loss:  1.4879115951771382e-05
loss:  3.3393916964996606e-05
loss:  3.068727164645679e-05
loss:  1.1735791304090526e-05
loss:  1.3357407624425832e-07
loss:  6.764242698409362e-06
loss:  1.9035425793845206e-05
loss:  1.9952953152824193e-05
loss:  9.178495929518249e-06
loss:  4.976278091817221e-07
loss:  2.7312034944770858e-06
loss:  1.0013219252869021e-05
loss:  1.1602534868870862e-05
loss:  5.797338872071123e-06
loss:  4.113908858016657e-07
loss:  1.4342235772346612e-06
loss:  5.989348665025318e-06
loss:  7.375299446721328e-06
loss:  3.998892680101562e-06
loss:  4.2068927541549783e-07
loss:  6.325563504105958e-07
loss:  3.202781499567209e-06
loss:  4.100205842405558e-06
loss:  2.1982234557071934e-06
loss:  1.9405547391215805e-07
loss:  4.73600323402934e-07
loss:  2.1346897938201437e-06
loss:  2.6749878543341765e-06
loss:  1.4465717868006323e-06
loss:  1.4582843732569017e-07
loss:  2.3943954374772147e-07
loss:  1.1561919563973788e-06
loss:  1.4047566310182447e-06
loss:  6.778419106012734e-07
loss:  2.8631802706513554e-08
loss:  2.509316345822299e-07
loss:  8.661324955028249e-07
loss:  9.58465079747839e-07
loss:  4.4112675823271275e-07
loss:  2.068622961814981e-08
loss:  1.3861813386029098e-07
loss:  4.55244389740983e-07
loss:  4.510482085606782e-07
loss:  1.5433130329256528e-07
loss:  1.1869474292325322e-09
loss:  1.6850668771439814e-07
loss:  3.73103375750361e-07
loss:  3.215137098777632e-07
loss:  1.0316830412193667e-07
loss:  1.460307430534158e-10
loss:  8.736839163248078e-08
loss:  1.7326237866654992e-07
loss:  1.1830303492388339e-07
loss:  1.6191336271731416e-08
loss:  1.859552867244929e-08
loss:  1.1124279808427673e-07
loss:  1.5508044270973187e-07
loss:  9.294950586991035e-08
loss:  1.3275212040753104e-08
loss:  7.75889930082485e-09
loss:  4.7226023980329046e-08
loss:  5.3971746183378855e-08
loss:  1.7632885374041507e-08
loss:  5.916831469221506e-10
loss:  3.007107807206921e-08
loss:  6.265560159590677e-08
loss:  5.392854518504464e-08
loss:  1.8428409020998515e-08
loss:  5.690026227966882e-11
loss:  9.340737960883416e-09
loss:  1.8043010641122237e-08
loss:  9.34875288294279e-09
loss:  7.190692485892214e-11
loss:  8.770314252615208e-09
loss:  2.557277412051917e-08
loss:  2.8033127819071524e-08
loss:  1.384620418320992e-08
loss:  1.545231498312205e-09
loss:  1.1041265679523349e-09
loss:  4.812136467080563e-09
loss:  3.2833327168191317e-09
loss:  2.0691004465334117e-11
loss:  3.639399892563233e-09
loss:  1.146554495790042e-08
loss:  1.4196075426298194e-08
loss:  8.747122137719998e-09
loss:  2.0532411326712463e-09
loss:  1.5973000699887052e-11
loss:  9.15179043659009e-10
loss:  5.866809260623995e-10
loss:  4.001776687800884e-11
loss:  2.049148406513268e-09
loss:  6.149605269456515e-09
loss:  7.649362032680074e-09
loss:  5.2387463256309275e-09
loss:  1.7826664588938002e-09
loss:  1.1164047464262694e-10
loss:  3.68913788406644e-11
loss:  9.606537787476555e-12
loss:  2.2788526621297933e-10
loss:  1.7835191101767123e-09
loss:  4.012292720290134e-09
loss:  4.403375442052493e-09
loss:  3.284696958871791e-09
loss:  1.3110934560245369e-09
loss:  3.283275873400271e-10
loss:  8.236611392931081e-11
loss:  1.1164047464262694e-10
loss:  5.821334525535349e-10
loss:  1.5370460459962487e-09
loss:  2.6284396881237626e-09
loss:  2.9468196771631483e-09
loss:  2.1852883946849033e-09
loss:  1.1010001799149904e-09
loss:  5.125002644490451e-10
loss:  3.2883917810977437e-10
loss:  3.851710062008351e-10
loss:  7.372022992058191e-10
loss:  1.537557636765996e-09
loss:  2.0465904526645318e-09
loss:  2.0465904526645318e-09
loss:  1.537273419671692e-09
loss:  9.09551545191789e-10
loss:  5.821334525535349e-10
loss:  5.821334525535349e-10
loss:  6.571099220309407e-10
loss:  9.09551545191789e-10
loss:  1.3097292139718775e-09
loss:  1.657781467656605e-09
loss:  1.537557636765996e-09
loss:  1.310183961322764e-09
loss:  9.100062925426755e-10
loss:  5.825881999044213e-10
loss:  5.825881999044213e-10
loss:  8.21046342025511e-10
loss:  1.100545432564104e-09
loss:  1.4210854715202004e-09
loss:  1.5371028894151095e-09
loss:  1.3097292139718775e-09
loss:  9.09551545191789e-10
loss:  7.369180821115151e-10
loss:  6.573372957063839e-10
loss:  7.369180821115151e-10
loss:  9.09551545191789e-10
loss:  1.3097292139718775e-09
loss:  1.3097292139718775e-09
loss:  1.3097292139718775e-09
loss:  1.0027179087046534e-09
loss:  7.367475518549327e-10
loss:  7.367475518549327e-10
loss:  7.369180821115151e-10
loss:  9.097220754483715e-10
loss:  1.1007159628206864e-09
loss:  1.2030341167701408e-09
loss:  1.100545432564104e-09
loss:  1.0027179087046534e-09
loss:  9.09551545191789e-10
loss:  7.367475518549327e-10
loss:  7.367475518549327e-10
loss:  9.09551545191789e-10
loss:  9.09551545191789e-10
loss:  1.100545432564104e-09
loss:  1.100545432564104e-09
loss:  9.09551545191789e-10
loss:  9.09551545191789e-10
loss:  7.367475518549327e-10
loss:  7.367475518549327e-10
loss:  9.09551545191789e-10
loss:  9.09551545191789e-10
loss:  1.100545432564104e-09
loss:  1.0027179087046534e-09
loss:  9.094947017729282e-10
loss:  8.208189683500677e-10
loss:  7.367475518549327e-10
loss:  7.367475518549327e-10
loss:  9.09551545191789e-10
loss:  9.09551545191789e-10
loss:  1.100545432564104e-09
loss:  9.09551545191789e-10
loss:  9.09551545191789e-10
loss:  7.367475518549327e-10
loss:  7.367475518549327e-10
loss:  8.208189683500677e-10
loss:  9.09551545191789e-10
loss:  9.09551545191789e-10
loss:  9.09551545191789e-10
loss:  7.367475518549327e-10
loss:  7.367475518549327e-10
loss:  9.09551545191789e-10
loss:  9.09551545191789e-10
loss:  9.09551545191789e-10
loss:  7.367475518549327e-10
loss:  7.367475518549327e-10
loss:  8.208189683500677e-10
loss:  9.09551545191789e-10
loss:  9.09551545191789e-10
loss:  7.367475518549327e-10
loss:  7.367475518549327e-10
loss:  7.367475518549327e-10
loss:  9.09551545191789e-10
loss:  9.09551545191789e-10
loss:  7.367475518549327e-10
loss:  7.367475518549327e-10
loss:  7.367475518549327e-10
loss:  8.208189683500677e-10
loss:  8.208189683500677e-10
loss:  8.208189683500677e-10
loss:  8.208189683500677e-10
loss:  7.367475518549327e-10
loss:  7.367475518549327e-10
loss:  7.367475518549327e-10
loss:  7.367475518549327e-10
loss:  9.09551545191789e-10
loss:  9.09551545191789e-10
loss:  7.367475518549327e-10
loss:  7.367475518549327e-10
loss:  5.821334525535349e-10
loss:  7.367475518549327e-10
loss:  7.367475518549327e-10
loss:  9.09551545191789e-10
loss:  9.09551545191789e-10
loss:  7.366907084360719e-10
loss:  6.571099220309407e-10
loss:  5.821334525535349e-10
loss:  7.367475518549327e-10
loss:  7.367475518549327e-10
loss:  9.09551545191789e-10
loss:  7.367475518549327e-10
loss:  7.366907084360719e-10
loss:  7.367475518549327e-10
loss:  5.821334525535349e-10
loss:  7.367475518549327e-10
loss:  7.366907084360719e-10
loss:  7.366907084360719e-10
loss:  7.366907084360719e-10
loss:  7.366907084360719e-10
loss:  7.366907084360719e-10
loss:  7.366907084360719e-10
loss:  7.367475518549327e-10
loss:  5.821334525535349e-10
loss:  5.821334525535349e-10
loss:  7.367475518549327e-10
loss:  7.367475518549327e-10
loss:  8.208189683500677e-10
loss:  7.367475518549327e-10
loss:  7.367475518549327e-10
loss:  5.821334525535349e-10
loss:  5.821334525535349e-10
loss:  7.367475518549327e-10
loss:  7.367475518549327e-10
loss:  7.367475518549327e-10
loss:  6.571099220309407e-10
loss:  6.571099220309407e-10
loss:  6.571099220309407e-10
loss:  6.571099220309407e-10
loss:  6.571099220309407e-10
loss:  6.571099220309407e-10
loss:  6.571099220309407e-10
Parameter containing:
tensor([[2.5000],
        [2.5000]], requires_grad=True)
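Because y is quadratic, the result can also be confirmed without iterating: setting both partial derivatives to zero gives the linear system 22x1 + 18x2 = 100 and 18x1 + 22x2 = 100. A short check with `torch.linalg.solve`:

```python
import torch

# dy/dx1 = 22*x1 + 18*x2 - 100 = 0
# dy/dx2 = 18*x1 + 22*x2 - 100 = 0
H = torch.tensor([[22.0, 18.0],
                  [18.0, 22.0]])
b = torch.tensor([100.0, 100.0])

x = torch.linalg.solve(H, b)
print(x)  # tensor([2.5000, 2.5000])
```

This matches the parameters found by AdamW above.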


Reposted from blog.csdn.net/sdhdsf132452/article/details/131241303