In [0]:
# Change directory to VSCode workspace root so that relative path loads work correctly. Turn this addition off with the DataScience.changeDirOnImportExport setting
# ms-python.python added
import os
try:
	os.chdir(os.path.join(os.getcwd(), '../..'))
	print(os.getcwd())
except:
	pass
In [1]:
import numpy as np
import scipy.stats
import torch
from torch.distributions import constraints
from matplotlib import pyplot as plt

import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer.autoguide import AutoDelta, AutoDiagonalNormal, AutoContinuous
from pyro.optim import Adam
from pyro.infer import SVI, Trace_ELBO, config_enumerate, infer_discrete

from content.prepare.tools import obj_dic
pyro.enable_validation(True)

#### alternative configurations
auto_guide = False
In [2]:
def gen_data(N):
    a = np.random.normal(0, 20, (1,2))
    p = a + np.random.normal(0, 1, (N, 2))
    bestμ = np.mean(p, axis=0)
    bestσ = np.std(p, axis=0)
    return torch.tensor(p, dtype=torch.float), obj_dic(locals())
    
data, gt = gen_data(100)
In [3]:
plt.scatter(data[:,0], data[:,1], label="Observations")
plt.legend()
Out[3]:
<matplotlib.legend.Legend at 0x7fa50c9d4a90>
In [4]:
pyro.clear_param_store()

def model(data):
    a = pyro.sample('a', dist.MultivariateNormal(torch.zeros(2), torch.eye(2)*100**2))
    with pyro.plate('data', len(data)):
        pyro.sample('obs', dist.MultivariateNormal(a, torch.eye(2)), obs=data)
        

def custom_guide(data):
    N = data.shape[0]
     = pyro.param('qaμ', torch.tensor([0, 0], dtype=torch.float)) #torch.ones(2)*0)
     = pyro.param('qaσ', torch.tensor(10.**2), constraint=constraints.positive)
    a = pyro.sample('a', dist.MultivariateNormal(, torch.eye(2)*))
In [5]:
if auto_guide:
    guide = AutoDiagonalNormal(model)
else:
    guide = custom_guide

optimizer = pyro.optim.Adam({'lr': 1, 'betas': [0.8, 0.99]})
elbo = Trace_ELBO(max_plate_nesting=1)

svi = SVI(model, guide, optimizer, loss=elbo)
In [6]:
n_steps = 200

for step in range(n_steps):
    v = svi.step(data)
    print(v)
71216.87931060791
75668.71525859833
21563.728432178497
26858.80811882019
21837.197683811188
19211.849675655365
20734.990266799927
18838.280572652817
17540.236302375793
14454.154158353806
13971.398915529251
10249.96979689598
11772.31806731224
7378.357187986374
9357.981699228287
6921.208762645721
5960.253732442856
4641.260904550552
5366.632966518402
3324.0061869621277
2915.7325859069824
2505.3869891166687
1619.0072875022888
2421.3061327934265
1994.7780888080597
2000.9173135757446
1263.2282831668854
1172.0913430452347
1604.9619064331055
600.2520158290863
668.2877774238586
768.4675137996674
351.1283600330353
370.3160216808319
442.8689516186714
349.96266531944275
569.8004233837128
300.7386665344238
374.49730944633484
317.0641767978668
291.8995957374573
320.81824016571045
301.07752871513367
305.5737340450287
296.7695165872574
398.60035705566406
354.8935046195984
306.5265531539917
298.12584352493286
305.09562134742737
322.5237789154053
304.71664810180664
286.2614486217499
285.76308381557465
299.6047774553299
380.3847393989563
324.2028069496155
285.14400029182434
333.67971658706665
292.980850815773
350.542014837265
289.9520733356476
312.7325928211212
321.9056673049927
293.5115839242935
306.30480670928955
291.58669912815094
341.50536155700684
306.56704092025757
287.4663212299347
295.9534480571747
316.0512729883194
291.1807496547699
346.32222533226013
289.15401804447174
321.2121436595917
295.7508546113968
305.3177958726883
371.9384562969208
368.9077298641205
323.85498559474945
306.77761697769165
319.0337806940079
285.1592757701874
295.4551000595093
286.08898401260376
296.658029794693
309.5172482728958
288.0212106704712
315.73852014541626
294.7038822174072
318.6416130065918
321.12588238716125
299.5009319782257
287.9149479866028
327.73888969421387
341.5383470058441
296.3519539833069
286.74657130241394
300.485392332077
338.65857100486755
338.90750074386597
294.40932297706604
303.23690271377563
417.24887561798096
306.97641706466675
289.65798234939575
333.6621935367584
294.5790011882782
307.35924458503723
286.5548529624939
311.5830148458481
319.88227581977844
317.5451102256775
353.6185803413391
312.17556846141815
289.36925864219666
303.9731857776642
350.5454956293106
323.66881573200226
327.75743436813354
299.50165152549744
320.5415052175522
325.30294370651245
286.6506235599518
286.54176807403564
290.5851591825485
297.714656829834
299.38247442245483
305.3342881202698
285.578307390213
295.1361999511719
292.17748522758484
292.82996475696564
290.93867015838623
287.4499887228012
286.1421823501587
303.8106632232666
293.7601869106293
286.9966766834259
308.9531216621399
290.36455750465393
290.3004322052002
290.6210448741913
325.08542823791504
331.07824182510376
291.20895552635193
294.6995973587036
309.63303089141846
297.3207414150238
302.7495995759964
286.87888073921204
289.6363170146942
287.1268039941788
312.8485028743744
300.23976278305054
317.4951386451721
318.38526701927185
292.6602153778076
287.56225085258484
305.2978632450104
312.478386759758
307.2999531030655
353.8740358352661
286.0489504337311
301.3817026615143
289.9423887729645
287.5336170196533
331.18577790260315
289.05829095840454
292.7839480638504
315.488404750824
296.7764182090759
314.08885180950165
293.1702914237976
302.52917516231537
291.3313961029053
298.01973390579224
295.40176272392273
291.06069922447205
294.254269361496
299.0681276321411
320.96804213523865
286.7587729692459
287.30614149570465
335.83106422424316
297.7001736164093
287.3528277873993
289.07236981391907
289.3441022634506
289.50313925743103
295.1135792732239
302.5033619403839
296.73228669166565
297.3661115169525
290.49700927734375
292.8470211029053
290.3405203819275
293.2078673839569
293.03975892066956
In [7]:
def p2d(name):
    if auto_guide:
        p = guide.get_posterior().mean.detach().numpy()
    else:
        p = pyro.param(name).detach().numpy()
    plt.scatter(p[0], p[1], marker='x', label='Estimated mean')

plt.scatter(data[:,0], data[:,1], alpha=0.2, label='Observations')
plt.scatter(gt.bestμ[0], gt.bestμ[1], marker='+', label='Actual mean')
plt.scatter(gt.a[0,0], gt.a[0,1], marker='+', label='Source mean')
p2d('qaμ')
plt.legend()
if auto_guide:
    v = guide.get_posterior().variance.detach().numpy()
else:
    v = pyro.param('qaσ').detach().numpy()
print('variance on the mean:', v)
#print(pyro.param('uα').detach().numpy(), pyro.param('uβ').detach().numpy())
variance on the mean: 0.11611741