% Uebung 21
%
% Trainieren diskreter HMMs mit dem Baum-Welch-Algorithmus
% -----------------------------------------------------------------
%
% Uebung 21.1
%
% Testen der Funktion fuer den Baum-Welch-Algorithmus
% ---------------------------------------------------
%
% Dieses Matlab-Skript dient zum Testen des Baum-Welch-Algorithmus, der 
% als Matlab-Funktion  [ns,nt,nb] = discr_baum_welch_alg(a,b,X)
% vorliegen muss, mit
% 
%   a      Zustandsuebergangswahrscheinlichkeiten des HMM mit
%            N Zustaenden (NxN-Matrix)
%   b      Beobachtungswahrscheinlichkeiten pro Zustand und 
%            diskrete Beobachtung 1..M (NxM-Matrix)
%   X      Trainingsfolge
%   ns     erwartete Anzahl Male in Zustaenden
%   nt     erwartete Anzahl Male der Zustandsuebergaenge
%   nb     erwartete Anzahl der Observationen in einem Zustand

% Test data: HMMs
%
% Model k is stored in ddhmm(k) with the fields
%   a   state transition probabilities (NxN matrix)
%   b   observation probabilities per state and discrete symbol (NxM matrix)
% In every model below the first column and last row of a are zero, and the
% first and last rows of b are zero (presumably non-emitting start/end
% states -- confirm against discr_baum_welch_alg).

% Start from a clean variable, as is done for X: the test loop iterates over
% length(ddhmm), so stale elements left in the workspace by an earlier
% session would make it index past the end of X.
clear ddhmm

                             % 3 states with 2 discrete observations
ddhmm(1).a = [0 1.0 0;  0 0.5 0.5;  0 0 0];
ddhmm(1).b = zeros(3,2); 
ddhmm(1).b(2,:) = [0.8 0.2];

                             % 3 states with 2 discrete observations
ddhmm(2).a = [0 1.0 0;  0 0.8 0.2;  0 0 0];
ddhmm(2).b = zeros(3,2); 
ddhmm(2).b(2,:) = [0.3 0.7];

                             % 3 states with 4 discrete observations
                             % (symbols 3 and 4 have probability zero)
ddhmm(3).a = [0 1.0 0;  0 0.8 0.2;  0 0 0];
ddhmm(3).b = zeros(3,4);
ddhmm(3).b(2,:) = [0.3 0.7 0 0];

                             % 5 states with 2 discrete observations
ddhmm(4).a = [0 1.0 0 0 0;  0 0.3 0.7 0 0;  0 0 0.7 0.3 0; ...
              0 0 0 0.7 0.3;  0 0 0 0 0];
ddhmm(4).b = zeros(5,2);
ddhmm(4).b(2,:) = [0.5 0.5];
ddhmm(4).b(3,:) = [0.6 0.4];
ddhmm(4).b(4,:) = [0.2 0.8];

                         % ergodic HMM, 4 states with 2 discrete observations
ddhmm(5).a = [0 0.1 0.9 0;  0 0 0.5 0.5;  0 0.5 0.3 0.2; 0 0 0 0];
ddhmm(5).b = zeros(4,2); 
ddhmm(5).b(2,:) = [0.5 0.5];
ddhmm(5).b(3,:) = [0.7 0.3];

                         % ergodic HMM, 4 states with 3 discrete observations
ddhmm(6).a = [0 0.8 0.2 0;  0 0.5 0.3 0.2;  0 0.3 0.6 0.1;  0 0 0 0];
ddhmm(6).b = zeros(4,3); 
ddhmm(6).b(2,:) = [0.2 0.3 0.5];
ddhmm(6).b(3,:) = [0.3 0.4 0.3];


% Test data: observation sequences
%
% X{k}{i} is the i-th observation sequence used with model ddhmm(k).
% Each sequence is a row vector of discrete observation symbols.
clear X

X = { ...
    { 1, [1 1], [2 2 1], [2 2 1 1 2] }, ...   % sequences for model 1
    { 1, [1 2], [2 2 1], [2 2 2]     }, ...   % sequences for model 2
    { 1, [1 2]                       }, ...   % sequences for model 3
    { [1 1 1], [1 2 1 1]             }, ...   % sequences for model 4
    { 1, [1 1], [1 2], [1 1 1]       }, ...   % sequences for model 5
    { 1, [1 2], [1 2 3], [1 1 1]     }  ...   % sequences for model 6
};

% Expected output
%
% out(k,i) holds the reference result for model ddhmm(k) and observation
% sequence X{k}{i}:
%   ns   reference for the first output of discr_baum_welch_alg (row vector)
%   nt   reference for the second output (NxN matrix)
%   nb   reference for the third output (NxM matrix)
% Non-integer values are given with 8 significant digits.

% ---- model 1 ----
out(1,1).ns = [1 1 1];
out(1,1).nt = [0 1 0; 0 0 1; 0 0 0];
out(1,1).nb = [0 0; 1 0; 0 0];

out(1,2).ns = [1 2 1];
out(1,2).nt = [0 1 0; 0 1 1; 0 0 0];
out(1,2).nb = [0 0; 2 0; 0 0];

out(1,3).ns = [1 3 1];
out(1,3).nt = [0 1 0; 0 2 1; 0 0 0];
out(1,3).nb = [0 0; 1 2; 0 0];

out(1,4).ns = [1 5 1];
out(1,4).nt = [0 1 0; 0 4 1; 0 0 0];
out(1,4).nb = [0 0; 2 3; 0 0];

% ---- model 2 ----
out(2,1).ns = [1 1 1];
out(2,1).nt = [0 1 0; 0 0 1; 0 0 0];
out(2,1).nb = [0 0; 1 0; 0 0];

out(2,2).ns = [1 2 1];
out(2,2).nt = [0 1 0; 0 1 1; 0 0 0];
out(2,2).nb = [0 0; 1 1; 0 0];

out(2,3).ns = [1 3 1];
out(2,3).nt = [0 1 0; 0 2 1; 0 0 0];
out(2,3).nb = [0 0; 1 2; 0 0];

out(2,4).ns = [1 3 1];
out(2,4).nt = [0 1 0; 0 2 1; 0 0 0];
out(2,4).nb = [0 0; 0 3; 0 0];

% ---- model 3 ----
out(3,1).ns = [1 1 1];
out(3,1).nt = [0 1 0; 0 0 1; 0 0 0];
out(3,1).nb = [0 0 0 0; 1 0 0 0; 0 0 0 0];

out(3,2).ns = [1 2 1];
out(3,2).nt = [0 1 0; 0 1 1; 0 0 0];
out(3,2).nb = [0 0 0 0; 1 1 0 0; 0 0 0 0];

% ---- model 4 ----
out(4,1).ns = [1 1 1 1 1];
out(4,1).nt = [0 1 0 0 0; ...
               0 0 1 0 0; ...
               0 0 0 1 0; ...
               0 0 0 0 1; ...
               0 0 0 0 0];
out(4,1).nb = [0 0; 1 0; 1 0; 1 0; 0 0];

out(4,2).ns = [1 1.2866242 1.5350318 1.1783439 1];
out(4,2).nt = [0 1         0          0          0; ...
               0 0.2866242 1          0          0; ...
               0 0         0.53503185 1          0; ...
               0 0         0          0.17834395 1; ...
               0 0         0          0          0];
out(4,2).nb = [0          0; ...
               1          0.2866242; ...
               0.82165605 0.7133758; ...
               1.1783439  0; ...
               0          0];

% ---- model 5 ----
out(5,1).ns = [1 0.16556291 0.83443709 1];
out(5,1).nt = [0 0.16556291 0.83443709 0; ...
               0 0          0          0.16556291; ...
               0 0          0          0.83443709; ...
               0 0          0          0];
out(5,1).nb = [0          0; ...
               0.16556291 0; ...
               0.83443709 0; ...
               0          0];

out(5,2).ns = [1 0.75660013 1.2433999 1];
out(5,2).nt = [0 0.03219575 0.96780425 0; ...
               0 0          0.03219575 0.72440438; ...
               0 0.72440438 0.24339987 0.27559562; ...
               0 0          0          0];
out(5,2).nb = [0          0; ...
               0.75660013 0; ...
               1.2433999  0; ...
               0          0];

out(5,3).ns = [1 0.87618736 1.1238126 1];
out(5,3).nt = [0 0.016377334 0.98362267  0; ...
               0 0           0.016377334 0.85981002; ...
               0 0.85981002  0.12381264  0.14018998; ...
               0 0           0           0];
out(5,3).nb = [0           0; ...
               0.016377334 0.85981002; ...
               0.98362267  0.14018998; ...
               0           0];

out(5,4).ns = [1 0.90652191 2.0934781 1];
out(5,4).nt = [0 0.081086855 0.91891314 0; ...
               0 0           0.38698337 0.51953853; ...
               0 0.82543505  0.78758157 0.48046147; ...
               0 0           0          0];
out(5,4).nb = [0          0; ...
               0.90652191 0; ...
               2.0934781  0; ...
               0          0];

% ---- model 6 ----
out(6,1).ns = [1 0.84210526 0.15789474 1];
out(6,1).nt = [0 0.84210526 0.15789474 0; ...
               0 0          0          0.84210526; ...
               0 0          0          0.15789474; ...
               0 0          0          0];
out(6,1).nb = [0          0 0; ...
               0.84210526 0 0; ...
               0.15789474 0 0; ...
               0          0 0];

out(6,2).ns = [1 1.3636364 0.63636364 1];
out(6,2).nt = [0 0.72727273 0.27272727 0; ...
               0 0.51948052 0.20779221 0.63636364; ...
               0 0.11688312 0.15584416 0.36363636; ...
               0 0          0          0];
out(6,2).nb = [0          0          0; ...
               0.72727273 0.63636364 0; ...
               0.27272727 0.36363636 0; ...
               0          0          0];

out(6,3).ns = [1 1.9568023 1.0431977 1];
out(6,3).nt = [0 0.69833303 0.30166697 0; ...
               0 0.86216168 0.35436458 0.74027604; ...
               0 0.39630758 0.38716616 0.25972396; ...
               0 0          0          0];
out(6,3).nb = [0          0          0; ...
               0.69833303 0.51819322 0.74027604; ...
               0.30166697 0.48180678 0.25972396; ...
               0          0          0];

out(6,4).ns = [1 1.6300211 1.3699789 1];
out(6,4).nt = [0 0.67653277 0.32346723 0; ...
               0 0.64633041 0.45937783 0.5243129; ...
               0 0.30715796 0.5871338  0.4756871; ...
               0 0          0          0];
out(6,4).nb = [0         0 0; ...
               1.6300211 0 0; ...
               1.3699789 0 0; ...
               0         0 0];

% Test driver: run discr_baum_welch_alg on every model/sequence pair and
% compare its three outputs with the reference values in out(k,i).
% On the first mismatch the expected and computed values are printed and
% the script stops; 'Test ok' is printed only if all comparisons pass.

% Largest accepted relative deviation from the reference values.
tolerance = 1e-6;

% Element-wise relative error. Where the reference is zero the divisor is
% replaced by 1, so the absolute error is used for those entries.
rel_err = @(ref, val) abs((ref - val) ./ (ref + (ref == 0)));

% True if val differs from ref in size or in any entry by more than the
% tolerance. The test is written as ~all(... <= tolerance) rather than
% any(... > tolerance) because every comparison with NaN is false: with the
% latter form a NaN result would silently pass. The size test comes first
% so that rel_err is never evaluated on arrays of different size.
mismatch = @(ref, val) ~isequal(size(ref), size(val)) || ...
    ~all(all(rel_err(ref, val) <= tolerance));

for k = 1:length(ddhmm)
    for i = 1:length(X{k})
        X_used = X{k}{i};       % observation sequence used for this test
        out_ref = out(k,i);     % reference result for this model/sequence
        [ns,nt,nb] = discr_baum_welch_alg(ddhmm(k).a,ddhmm(k).b,X_used);

        % Uncomment to print the computed values in the format used for
        % the reference data:
%        fprintf('out(%d,%d).ns = [ %s ]\n', k, i, num2str(ns,8));
%        fprintf('out(%d,%d).nt = [\n', k, i); disp(num2str(nt,8)); fprintf('];\n');
%        fprintf('out(%d,%d).nb = [\n', k, i); disp(num2str(nb,8)); fprintf('];\n');

        if mismatch(out_ref.ns, ns)
            disp(['Fehler mit HMM ' num2str(k) ' und Beobachtungssequenz ' ...
                num2str(i) ':']);
            fprintf('ns vorgegeben: '); disp(out_ref.ns);
            fprintf('ns berechnet:  '); disp(ns);
            return;
        end

        if mismatch(out_ref.nt, nt)
            disp(['Fehler mit HMM ' num2str(k) ' und Beobachtungssequenz ' ...
                num2str(i) ':']);
            fprintf('nt vorgegeben:\n'); disp(out_ref.nt);
            fprintf('nt berechnet:\n'); disp(nt);
            return;
        end

        if mismatch(out_ref.nb, nb)
            disp(['Fehler mit HMM ' num2str(k) ' und Beobachtungssequenz ' ...
                num2str(i) ':']);
            fprintf('nb vorgegeben:\n'); disp(out_ref.nb);
            fprintf('nb berechnet:\n'); disp(nb);
            return;
        end
    end
end
disp('Test ok');
