CUV  0.9.201304091348
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Groups Pages
filter_factory.hpp
1 #ifndef __FILTER_FACTORY_HPP__
2 #define __FILTER_FACTORY_HPP__
3 
5 namespace cuv{
6  template<class T, class M, class I=unsigned int>
8  public:
9  typedef T value_type;
10  typedef M memory_space;
11  typedef I index_type;
12  public:
13  filter_factory(int px, int py, int fs, int input_maps, int output_maps)
14  : m_px(px)
15  , m_py(py)
16  , m_fs(fs)
17  , m_input_maps(input_maps)
18  , m_output_maps(output_maps)
19  {
20  }
21 
22  template<class M2>
24  extract_filter( const dia_matrix<T,M2>& dia, unsigned int filternumber){
25  tensor<T,M,row_major>* mat = new tensor<T,M,row_major>(extents[m_fs*m_fs][m_input_maps]);
26  fill(*mat, 0);
27  unsigned int map_size=dia.h()/m_input_maps;
28  for (unsigned int map_num = 0; map_num < m_input_maps; map_num++)
29  {
30  unsigned int fi = 0;
31  for (unsigned int i = 0; i < map_size; ++i)
32  {
33  if(!dia.has(i+map_num*map_size,filternumber))
34  continue;
35  mat[map_num * m_fs *m_fs + fi++]= dia(i+map_num*map_size,filternumber);
36  if(fi>=mat->size())
37  break;
38  }
39  }
40  return mat;
41  }
42 
44  get_dia(){
45  int fs = m_fs;
46  int nm = m_output_maps;
47  int msize = fs*fs*( nm + m_input_maps-1 );
48  int* off = new int[ msize ];
49  int offidx=0;
50  int dias_per_filter = 0;
52  m_px*m_py*m_input_maps,
53  m_px*m_py*m_output_maps,
54  msize,
55  std::max(m_px*m_py*m_output_maps,
56  m_px*m_py*m_input_maps));
57  for( int m=0;m<nm+m_input_maps-1;m++ ){
58  dias_per_filter = 0;
59  for( int i=0;i<fs;i++ ){
60  for( int j=0;j<fs;j++ ){
61  off[ offidx++ ] = i*m_px+j + m*m_px*m_py;
62  dias_per_filter ++;
63  }
64  }
65  }
66  cuvAssert( offidx == msize );
67 
68  m_dias_per_filter = dias_per_filter;
69 
70  for( int i=0;i<msize;i++ )
71  off[ i ] += -( m_px+1 )*( int( fs/2 ) ) - ( m_input_maps-1 ) * m_px*m_py;
72  tp->set_offsets( off, off+msize );
73 
74  delete off;
75  return tp;
76  }
77  private:
78  int m_px, m_py, m_fs, m_input_maps, m_output_maps;
79  int m_dias_per_filter;
80  };
81 }
82 
83 #endif /*__FILTER_FACTORY_HPP__ */