% Uebung 21
%
% Trainieren diskreter HMMs mit dem Baum-Welch-Algorithmus
% -----------------------------------------------------------------
%
% Uebung 21.2
%
% Testen der Funktion fuer das Baum-Welch-Training mit mehreren Sequenzen
% ----------------------------------------------------------------------
%
% Dieses Matlab-Skript dient zum Testen des Baum-Welch-Trainings, das 
% als Matlab-Funktion  [a,b] = discr_baum_welch_training(Xset,N,M,K)
% vorliegen muss, mit
% 
%   Xset   Satz von Trainingsfolgen (Matlab cell array) 
%   N      Anzahl Zustaende des HMM (inkl. Anfangs- und Endzustand)
%   M      Anzahl der diskreten Beobachtungen
%   K      Anzahl der Trainingsiterationen (im Baum-Welch-Training)
%   a      Zustandsuebergangswahrscheinlichkeiten des HMM mit
%            N Zustaenden (NxN-Matrix)
%   b      Beobachtungswahrscheinlichkeiten pro Zustand und 
%            diskrete Beobachtung 1..M (NxM-Matrix)


clear all


% Testdaten
% ----------

% Diskrete HMM 

ddhmm(1).name = 'mod1';
ddhmm(1).a = [0 1 0;  0 0 1;  0 0 0];
ddhmm(1).b = zeros(3,2);
ddhmm(1).b(2,:) = [0.5 0.5];

ddhmm(2).name = 'mod2';
ddhmm(2).a = [0 1 0;  0 0.5 0.5;  0 0 0];
ddhmm(2).b = zeros(3,2);
ddhmm(2).b(2,:) = [0.75 0.25];

ddhmm(3).name = 'mod3';
ddhmm(3).a = [0 1 0 0;  0 0.2 0.8 0;  0 0 0.2 0.8;  0 0 0 0];
ddhmm(3).b = zeros(4,2);
ddhmm(3).b(2,:) = [0.5 0.5];
ddhmm(3).b(3,:) = [0.5 0.5];


%----- Trainingsdaten -----

gen_tr_data = 0;          % falls =1, dann werden Trainingsdaten mit obigen 
Xset = {};                % HMM generiert, sonst die folgenden verwendet
if gen_tr_data ~= 1
                          % diese Trainingssets sind so gewaehlt, 
  Xset{1}{1} = [1];       % dass aus dem Training die vorgegebenen 
  Xset{1}{2} = [2];       % HMM-Parameter resultieren 
  
  Xset{2}{1} = [1];
  Xset{2}{2} = [2];
  Xset{2}{3} = [1 1 1];
  Xset{2}{4} = [1 2 1];

  Xset{3}{1} = [1 1];
  Xset{3}{2} = [2 2];
  Xset{3}{3} = [1 2 1];
  Xset{3}{4} = [2 1 2];

else 
  S = 1000;           % Anzahl Trainingssequenzen pro HMM
  rand('seed',0);
  for imod = 1:length(ddhmm)
    for s = 1:S 
      [Xset{imod}{s} Q] = gen_discr_obs_seq(ddhmm(imod));
    end
  end
end


%----- Testvorgaben -----

disp_res = 1;    % Resultate anzeigen

K = 20;          % Anzahl Trainingsiterationen (im Baum-Welch-Training)

tolerance = 0.05;    % Erlaubte Differenz zwischen den Parametern des
                     % trainierten und des vorgegebenen HMM


%----- Testlauf -----

disp(' ');
disp('Testen ...');

for imod = 1:length(ddhmm)
  tic
  disp(['Modell (',ddhmm(imod).name,')']);
  N = size(ddhmm(imod).a,1);
  M = size(ddhmm(imod).b,2);
  a1 = ddhmm(imod).a;
  b1 = ddhmm(imod).b;

%----- Berechnung (Baum-Welch-Training) -----

  [a2,b2] = discr_baum_welch_training(Xset{imod},N,M,K); 


%----- Test der Zustandsuebergangs-Wahrscheinlichkeiten -----

  delta = max(max(abs(a1-a2))); 
  if (delta > tolerance) & (imod ~= 4) | (disp_res == 1)
    disp(' ');
    disp('Uebergangswahrscheinlichkeiten (Soll): ');
    for j = 1:N, 
      for i = 1:N, fprintf(1,'  %6.4f',a1(j,i)); end;
      fprintf(1,'\n')
    end
    disp('Uebergangswahrscheinlichkeiten (Ist): ');
    for j = 1:N, 
      for i = 1:N, fprintf(1,'  %6.4f',a2(j,i)); end;
      fprintf(1,'\n')
    end
  end

  if (delta > tolerance) || any(isnan(a2(:)))
    disp(' ');
    disp(['Toleranz: ',num2str(tolerance)]);
    disp('Test fehlgeschlagen');
    disp(' ');
    return;
  end    

%----- Test Beobachtungswahrscheinlichkeiten -----

  delta = max(max(abs(b1-b2)));
  if isnan(b2(2,1)) | delta > tolerance | disp_res == 1 
    disp(' ');
    disp('Beobachtungswahrscheinlichkeiten (Soll): ');
    for j = 1:N, 
      for i = 1:M, fprintf(1,'  %6.4f',b1(j,i)); end;
      fprintf(1,'\n')
    end
    disp('Beobachtungswahrscheinlichkeiten (Ist): ');
    for j = 1:N, 
      for i = 1:M, fprintf(1,'  %6.4f',b2(j,i)); end;
      fprintf(1,'\n')
    end
  end

  if (delta > tolerance) || any(isnan(b2(:)))
    disp(' ');
    disp(['Toleranz: ',num2str(tolerance)]);
    disp('Test fehlgeschlagen');
    disp(' ');
    return;
  end
  tmUsed = toc;
  disp(['Rechenzeit: ',num2str(tmUsed),' s']);
end % (for)

disp('Test ok');
disp(' ');
format;
