001package cnslab.image;
002import cnslab.cnsmath.*;
003public class Convolution
004{
005        public static int [][] convolve(int[][] initial, double [][] response)
006        {
007                //find out the proper padding
008                int m_mx = response.length/2;
009                int m_my = response[0].length/2;
010                int m_px = response.length - response.length/2;
011                int m_py = response[0].length - response[0].length/2;
012
013                int ix = initial.length;
014                int iy = initial[0].length;
015
016                int rx = response.length;
017                int ry = response[0].length;
018
019
020                boolean adjustX=true;//padding for x
021                boolean adjustY=true;//padding for y
022
023                int fullX=-1,fullY=-1;
024
025                for(int i=1;i<32;i++) //maximum size 2^32
026                {
027                        if(ix==(1<<i))
028                        {
029                                fullX = ix;
030                                adjustX=false;
031                                break;
032                        }
033                }
034                for(int i=1;i<32;i++)
035                {
036                        if(iy==(1<<i))
037                        {
038                                fullY = iy;
039                                adjustY=false;
040                                break;
041                        }
042                }
043
044
045                if(adjustX)
046                {
047                        for(int i=1;i<32;i++) //maximum size 2^32
048                        {
049                                if(ix+m_mx+m_px<=(1<<i))
050                                {
051                                        fullX = 1<<i;
052                                        break;
053                                }
054                        }  
055                }
056
057                if(adjustY)
058                {
059                        for(int i=1;i<32;i++) //maximum size 2^32
060                        {
061                                if(iy+m_my+m_py<=(1<<i))
062                                {
063                                        fullY = 1<<i;
064                                        break;
065                                }
066                        }  
067                }
068
069
070                //prepare matrix for FFT
071                double [] sig = new double[fullX*fullY*2]; 
072                double [] res = new double[fullX*fullY*2]; 
073        //      System.out.println(fullX+" "+fullY+" "+ix+" "+iy+" mx"+m_mx+" px"+m_px+" my"+m_my+" py"+m_py);
074
075                for(int i=0; i< ix; i++)
076                {
077                        for(int j=0; j< iy; j++)
078                        {
079                                sig[i*fullY*2+j*2] = (double)initial[i][j];
080                        }       
081                }
082
083                for(int i=0; i< m_mx+m_px; i++)
084                {
085                        for(int j=0; j< iy; j++)
086                        {
087                                sig[(i+fullX-m_mx-m_px)*fullY*2+j*2] = (double)initial[i+ix-m_mx-m_px][j];
088                        }       
089                }
090
091                for(int i=0; i< ix; i++)
092                {
093                        for(int j=0; j< m_my+m_py; j++)
094                        {
095                                sig[i*fullY*2+(j+fullY-m_my-m_py)*2] = (double)initial[i][j+iy-m_my-m_py];
096                        }       
097                }
098
099                for(int i=0; i< m_mx+m_px; i++)
100                {
101                        for(int j=0; j< m_my+m_py; j++)
102                        {
103                                sig[(i+fullX-m_mx-m_px)*fullY*2+(j+fullY-m_my-m_py)*2] = (double)initial[i+ix-m_mx-m_px][j+iy-m_my-m_py];
104                        }       
105                }
106
107                for(int i=0; i< m_px ; i++) //left  top
108                {
109                        for(int j=0; j < m_py; j++)
110                        {
111                                res[i*fullY*2+j*2] = response[i][j];
112                        }
113                }
114
115                for(int i=0; i< m_mx ; i++) // left bot
116                {
117                        for(int j=0; j < m_py; j++)
118                        {
119                                res[(i+fullX-m_mx)*fullY*2+j*2] = response[i+rx-m_mx][j];
120                        }
121                }
122
123                for(int i=0; i< m_px ; i++) //right top
124                {
125                        for(int j=0; j < m_my; j++)
126                        {
127                                res[i*fullY*2+(j+fullY-m_my)*2] = response[i][j+ry-m_my];
128                        }
129                }
130
131                for(int i=0; i< m_mx ; i++) //right bot 
132                {
133                        for(int j=0; j < m_my; j++)
134                        {
135                                res[(i+fullX-m_mx)*fullY*2+(j+fullY-m_my)*2] = response[i+rx-m_mx][j+ry-m_my];
136                        }
137                }
138
139                //fourier transferformation for both
140                int [] nn = new int [] {fullX,fullY};
141                FFT.fourn(sig, nn, 1);
142                FFT.fourn(res, nn, 1);
143
144                for(int i =0 ; i< fullX*fullY; i++)
145                {
146                        double tmpR = sig[2*i];
147                        sig[2*i] = sig[2*i]*res[2*i] - sig[2*i+1]*res[2*i+1];
148                        sig[2*i+1] = tmpR*res[2*i+1] + sig[2*i+1]*res[2*i];
149                }
150
151                //reverse transformation
152                FFT.fourn(sig, nn , -1);
153
154                int totalDim=fullX*fullY;
155
156
157                int [][] out = new int[initial.length][initial[0].length];
158
159                for(int i=0; i< ix - m_mx ; i++)
160                {
161                        for(int j=0; j< iy - m_my; j++)
162                        {
163                                out[i][j] = (int)Math.round(sig[i*fullY*2+ j*2]/(double)totalDim);
164                        }
165                }
166
167                for(int i=0; i< m_mx ; i++)
168                {
169                        for(int j=0; j< iy - m_my; j++)
170                        {
171                                out[i+ix-m_mx][j] = (int)Math.round(sig[(i+fullX-m_mx)*fullY*2+ j*2]/(double)totalDim);
172                        }
173                }
174
175                for(int i=0; i< ix - m_mx ; i++)
176                {
177                        for(int j=0; j< m_my; j++)
178                        {
179                                out[i][j+iy-m_my] = (int)Math.round(sig[i*fullY*2+ (j+fullY-m_my)*2]/(double)totalDim);
180                        }
181                }
182
183                for(int i=0; i< m_mx ; i++)
184                {
185                        for(int j=0; j< m_my; j++)
186                        {
187                                out[i+ix-m_mx][j+iy-m_my] = (int)Math.round(sig[(i+fullX-m_mx)*fullY*2+ (j+fullY-m_my)*2]/(double)totalDim);
188                        }
189                }
190                return out;
191        }
192
193        public static int [][] convolve_safe(int[][] initial, double [][] response)
194        {
195                int [][] out = new int[initial.length][initial[0].length];
196
197                //find out the proper padding
198                int m_mx = response.length/2;
199                int m_my = response[0].length/2;
200
201                for(int i=0; i<initial.length; i++)
202                {
203                        for(int j=0;j<initial[0].length;j++)
204                        {
205                                double sum=0.0;
206
207                                for(int m=0; m<response.length;m++)
208                                {
209                                        for(int n=0; n<response[0].length;n++)
210                                        {
211                                                int oldx=i+m-m_mx;
212                                                int oldy=j+n-m_my;
213                                                oldx = oldx % initial.length;
214                                                oldy = oldy % initial[0].length;
215                                                if(oldx<0)oldx+=initial.length;
216                                                if(oldy<0)oldy+=initial[0].length;
217                                                sum=sum+ response[response.length-1-m][response[0].length-1-n]*(double)initial[oldx][oldy];
218                                        }
219                                }       
220                                out[i][j]= (int)Math.round(sum);
221                        }
222                }
223                return out;
224        }
225}