% RC_test1.m
%
% haya
%
% Uploaded on: Sept. 6, 2024, 8:44 p.m.
% .other

% Start from a clean workspace and a fixed RNG state so that every random
% draw below (and later in the script) is reproducible.
clear;
rng('default')

% --- Hyperparameters ------------------------------------------------------
N               = 100;   % number of reservoir neurons
inputScaling    = 0.6;   % input weight scaling (input strength)
spectral_radius = 0.95;  % target spectral radius of the recurrent matrix W
sigma           = 0;     % passed to the input generator later (presumably a noise level - confirm)
alpha           = 1.0;   % RLS regularizer: P starts as (1/alpha)*I

% --- Recurrent weights ----------------------------------------------------
% Entries drawn uniformly from (-1, 1), then the whole matrix is rescaled
% so that its largest eigenvalue magnitude equals spectral_radius.
W = 2 * rand(N, N) - 1;
W = (spectral_radius / max(abs(eig(W)))) * W;

% --- Input weights --------------------------------------------------------
% One weight per neuron, uniform in (-inputScaling, inputScaling).
win = inputScaling * (2 * rand(N, 1) - 1);

%% train

% Time axis: 600 integer sample indices (1..600). The training loop uses t
% only to set the number of steps; no physical time unit is attached here.
t  = 1:1:600;
ts = linspace(0, 2*pi, 600);   % NOTE(review): not used in the live code below
L  = 10;                       % number of target signals (rows of dl that are trained)
learn_every = 2;               % apply an RLS update every learn_every steps
tjitter_i   = 0;               % timing jitter passed to the input generator

trial  = 100;   % number of test trials (only referenced by the disabled predict section)
signal = 10;    % onset (rise time) of the input signal

dl      = multiple_InputTarget_pairs(L, t); % teacher signals, read as dl(l, ti)
train_U = SimpleAttentionTask_noise_input(length(t), sigma, signal, tjitter_i);

% Fail early with a clear message if the helpers return data that the
% training loop cannot index (it reads dl(l, ti) for l <= L and train_U(ti)).
assert(size(dl, 1) >= L && size(dl, 2) >= length(t), ...
    'dl must have at least L rows and length(t) columns');
assert(numel(train_U) >= length(t), ...
    'train_U must have at least length(t) samples');

% Reservoir state: random initial pre-activation x and its rate r = tanh(x).
x0 = 0.5*randn(N,1);
x  = x0;
r  = tanh(x);

% Linear readout z = wo'*r, trained with recursive least squares.
nRec2Out = N;
wo = zeros(nRec2Out,1);          % readout weights
dw = zeros(nRec2Out,1);          % most recent weight increment
P  = (1.0/alpha)*eye(nRec2Out);  % running inverse correlation matrix estimate

% Per-step traces. They are indexed by time step only, so each pass over a
% new target overwrites the previous pass.
wo_len = zeros(1,length(t));   % norm of wo after each step
zt     = zeros(1,length(t));   % readout output
et     = zeros(1,length(t));   % written at learning steps by the training loop
X      = zeros(N,length(t));   % reservoir rates r

% RLS training of the readout. For each target row l, the reservoir is driven
% by train_U and wo is updated every learn_every steps. The state (x, r), wo
% and P carry over from one target to the next. The traces are indexed by ti
% only, so after the loop they hold the pass for the last target.
for l = 1:L
    d = dl(l,:);                       % current target signal
    for ti = 1:length(t)
        % Reservoir step: the input drive plus recurrent drive, then tanh.
        % The readout is not fed back into the reservoir.
        x = win * train_U(ti) + W * r;
        r = tanh(x);
        z = wo'*r;                     % readout using the pre-update weights

        % Record output and state now; the learning step does not modify z or r.
        zt(ti)  = z;
        X(:,ti) = r;

        if ~mod(ti, learn_every)
            % Gain vector k and normalizer c, then rank-1 update of P.
            k   = P*r;
            rPr = r'*k;
            c   = 1.0/(1.0 + rPr);
            P   = P - k*(k'*c);

            % Error of the pre-update output against the target, then the
            % weight correction along k.
            e  = z-d(ti);
            dw = -e*k*c;
            wo = wo + dw;

            et(ti) = c;  % NOTE(review): stores c, not the error e - confirm intent
        end

        wo_len(ti) = sqrt(wo'*wo);     % Euclidean norm of wo after this step
    end
end

% Summary figure of the training run. The reservoir, output and weight-norm
% traces come from the last pass of the training loop (last target).
figure;

% 1) Input sequence fed to the reservoir.
subplot 411
plot(train_U);
title("input data")

% 2) Rates of a random subset of reservoir units over time.
subplot 412
element = min(10, N);                 % number of units shown (randperm needs element <= N)
random_rows = randperm(N, element);
plot(X(random_rows, :).');            % each column of the transposed matrix becomes one line
xlabel('Time');
ylabel('Value');
title(sprintf('Randomly Selected %d elements in reservoir (train)', element));

% 3) Teacher signal vs. readout output for the last trained target.
subplot 413
plot(d,'b','LineWidth',1.5);
hold on;
plot(zt,'r','LineWidth',1);
hold off;
title('train result');
legend('teacher','output');
xlabel('Time');
ylabel('Value');

% 4) Norm of the readout weight vector after each step.
subplot 414
plot(wo_len);
title('readout weight norm');
xlabel('Time');


% Training MSE between the last target d and the recorded readout trace zt.
% Computed directly instead of with immse, which requires the Image Processing
% Toolbox. The size check keeps immse's requirement that both inputs match,
% so a mismatch cannot silently broadcast.
assert(isequal(size(zt), size(d)), 'zt and d must have the same size');
MSE_train = mean((zt - d).^2);
disp(['Training MSE: ' num2str(MSE_train, 4)]);

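%% Prediction (disabled)
% The block comment below is not executed. Before enabling it, define
% tslen, washout and react, which are not set above. Also note that wo is
% an N-by-1 column, so the readout over X would be wo'*X rather than wo*X.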

%{
%% predict
x0 = zeros(N,1);
x = x0;
r = tanh(x);

X = zeros(N,tslen);

for test = 1:trial
    tjitter_i = randi([-2, 2]);
    Ts = tslen + washout;
    d = SimpleAttentionTask_teacher(Ts,react,tjitter_i);
    test_D = reshape(d(washout+1:end),[],1);

    idx = 1;
    test_U = SimpleAttentionTask_noise_input(Ts,sigma,signal,tjitter_i);
    for t = 1:Ts
        x = win*test_U(t) + W*r;
        r = tanh(x);

        if t > washout
        X(:,idx) = r;
        idx = idx + 1;
        end
    end
end

z_ts = wo*X;
z_ts = z_ts';

MSE = immse(z_ts,test_D);
disp(['predict MSE: ' num2str(MSE,4)]); 

%}