/* Logistic regression - Challenger in Picat. From https://www.zinkov.com/posts/2012-06-27-why-prob-programming-matters/ "Logistic Regression" """ Logistic Regression can be seen as a generalization of Linear Regression where the output is transformed to lie between 0 and 1. This model only differs from the previous one by a single line, illustrating that adding this complexity does not require starting from scratch. The point with probabilistic programming is you are able to explore slightly more complex models very easily. """ From https://www.stat.ubc.ca/~bouchard/courses/stat520-sp2014-15/lecture/2015/02/27/notes-lecture3.html x = 66,70,69,68,67,72,73,70,57,63,70,78,67,53,67,75,70,81,76,79,75,76,58 y = 1,0,1,1,1,1,1,1,0,0,0,1,1,0,1,1,1,1,1,1,0,1,0 The values 70 and 75 are the two valuse for which there are both true and false observations, and are the cause of the slightly bad prediction of about 82%. The p70 and p75 variables in the model are the probability that the value of 70 and 75 are false, respectively. The breakpoint when false (0) switch to true (1) is when x is around 66 (the breakpoint variable). Cf my Gamble model gamble_logistic_regression_challenger.rkt This program was created by Hakan Kjellerstrand, hakank@gmail.com See also my Picat page: http://www.hakank.org/picat/ */ import ppl_distributions, ppl_utils. import util. % import ordset. main => go. /* Num accepted samples: 98 Total samples: 234756 (0.000%) Num accepted samples: 99 Total samples: 238244 (0.000%) Num accepted samples: 100 Total samples: 247103 (0.000%) var : w0 Probabilities (truncated): -0.089408783831073: 0.0100000000000000 -0.339511089667556: 0.0100000000000000 -0.733880710054389: 0.0100000000000000 -0.759768740332827: 0.0100000000000000 ......... 
-24.089793587294036: 0.0100000000000000 -24.484329896217631: 0.0100000000000000 -28.139701399030642: 0.0100000000000000 -31.195240506224202: 0.0100000000000000 mean = -10.1822 HPD intervals: HPD interval (0.84): -17.74130072940775..-0.95084722694918 var : w1 Probabilities (truncated): 0.470155550399948: 0.0100000000000000 0.436024473778669: 0.0100000000000000 0.362842137216864: 0.0100000000000000 0.359522608747916: 0.0100000000000000 ......... 0.032742033270729: 0.0100000000000000 0.032452960739099: 0.0100000000000000 0.020320069772799: 0.0100000000000000 0.017453333731018: 0.0100000000000000 mean = 0.163964 HPD intervals: HPD interval (0.84): 0.03245296073910..0.27137885027228 var : p70 Probabilities: false: 0.7900000000000000 true: 0.2100000000000000 mean = [false = 0.79,true = 0.21] HPD intervals: show_hpd_intervals: data is not numeric var : p75 Probabilities: false: 0.8600000000000000 true: 0.1400000000000000 mean = [false = 0.86,true = 0.14] HPD intervals: show_hpd_intervals: data is not numeric var : breakpoint Probabilities (truncated): 61: 0.0900000000000000 70: 0.0600000000000000 69: 0.0600000000000000 65: 0.0600000000000000 ......... 
56: 0.0200000000000000 54: 0.0200000000000000 72: 0.0100000000000000 66: 0.0100000000000000 mean = 66.59 HPD intervals: HPD interval (0.84): 55.00000000000000..77.00000000000000 breakpoint_mean = 66.59 [i = 1,x = 66,y = 1,ok = false] [i = 2,x = 70,y = 0,ok = false] [i = 3,x = 69,y = 1,ok = true] [i = 4,x = 68,y = 1,ok = true] [i = 5,x = 67,y = 1,ok = true] [i = 6,x = 72,y = 1,ok = true] [i = 7,x = 73,y = 1,ok = true] [i = 8,x = 70,y = 1,ok = true] [i = 9,x = 57,y = 0,ok = true] [i = 10,x = 63,y = 0,ok = true] [i = 11,x = 70,y = 0,ok = false] [i = 12,x = 78,y = 1,ok = true] [i = 13,x = 67,y = 1,ok = true] [i = 14,x = 53,y = 0,ok = true] [i = 15,x = 67,y = 1,ok = true] [i = 16,x = 75,y = 1,ok = true] [i = 17,x = 70,y = 1,ok = true] [i = 18,x = 81,y = 1,ok = true] [i = 19,x = 76,y = 1,ok = true] [i = 20,x = 79,y = 1,ok = true] [i = 21,x = 75,y = 0,ok = false] [i = 22,x = 76,y = 1,ok = true] [i = 23,x = 58,y = 0,ok = true] [correct = 19,not_correct = 4,pct_correct = 0.826087] */
%
% Run the model on the Challenger data, then report how well the
% posterior mean of the breakpoint classifies the original observations
% (predict y = 1 when x >= mean breakpoint, otherwise y = 0).
%
go ?=>
  Xs = [66,70,69,68,67,72,73,70,57,63,70,78,67,53,67,75,70,81,76,79,75,76,58],
  Ys = [1,0,1,1,1,1,1,1,0,0,0,1,1,0,1,1,1,1,1,1,0,1,0],
  println([len=Xs.len,xs_min=Xs.min,xs_mean=Xs.mean,xs_max=Xs.max,stdev=Xs.stdev]),
  reset_store,
  run_model(10_000,$model(Xs,Ys),[show_probs_trunc,mean,
                                  % show_percentiles,show_histogram,
                                  show_hpd_intervals,hpd_intervals=[0.84],
                                  min_accepted_samples=100,show_accepted_samples=true
                                 ]),
  nl,
  % Posterior mean of the breakpoint drawn in model/2.
  BreakpointMean = get_store().get("breakpoint").mean,
  println(breakpoint_mean=BreakpointMean),
  NumCorrect = 0,
  foreach(I in 1..Xs.len)
    X = Xs[I],
    Y = Ys[I],
    % Predicted label from the breakpoint rule.
    Predicted = cond(X >= BreakpointMean, 1, 0),
    OK = cond(Predicted == Y, true, false),
    if OK == true then
      NumCorrect := NumCorrect + 1
    end,
    println([i=I,x=X,y=Y,ok=OK])
  end,
  % Every row is either correct or not, so the miss count is the complement.
  NumWrong = Xs.len - NumCorrect,
  println([correct=NumCorrect,not_correct=NumWrong,pct_correct=(NumCorrect/Xs.len)]),
  % show_store_lengths,nl,
  % fail,
  nl.
go => true.
%
% Logistic-regression model over observations Xs (temperatures) and
% Ys (0/1 O-ring failure indicators), sampled via ABC.
%
% Posterior quantities stored (when observed_ok holds):
%   "w0","w1"       : intercept and slope of the logistic link
%   "p70","p75"     : probability that x = 70 / x = 75 yields false
%   "breakpoint"    : x value where the prediction flips from 0 to 1
%
model(Xs,Ys) =>
  % Priors for intercept and slope.
  W0 = normal_dist(0,10),
  W1 = normal_dist(0,10),
  MinXs = Xs.min,
  MaxXs = Xs.max,
  % Restrict w0 to be negative and w1 positive.
  % This makes smaller xs values to be false and large to be true.
  observe(W0 < 0),
  observe(W1 > 0),
  % Logistic link: P(y = 1 | x) = 1/(1 + exp(-(w0 + w1*x))).
  % Built with a comprehension (O(n)) rather than repeated ++ (O(n^2)).
  Yss = [bernoulli_dist(1/(1 + exp(-(W0 + W1*X)))) : X in Xs],
  observe_abc(Ys,Yss,1/10),
  % Find the breakpoint value when false switches to true.
  % Note: this use the original data xs and ys.
  Breakpoint = uniform_draw(MinXs..MaxXs),
  % Note: Bs is only for presentation (post processing) purposes;
  % currently unused — see the commented-out add("bs",Bs) below.
  Bs = [cond(X < Breakpoint, 0, 1) : X in Xs],
  % What is the probability that the values of 70 and 75 are false?
  P70 = check(false == flip( 1 / (1 + exp(- ( W0 + W1*70 ))))),
  P75 = check(false == flip( 1 / (1 + exp(- ( W0 + W1*75 ))))),
  if observed_ok then
    % println([yss=Yss,bs=Bs,breakpoint=Breakpoint]),
    add("w0",W0),
    add("w1",W1),
    add("p70",P70),
    add("p75",P75),
    add("breakpoint",Breakpoint)
    % add("bs",Bs)
  end.