| /* |
| * Copyright (C) 2008 The Android Open Source Project |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| /* ---- includes ----------------------------------------------------------- */ |
| |
| #include "b_TensorEm/Int32Mat.h" |
| #include "b_TensorEm/Functions.h" |
| #include "b_BasicEm/Math.h" |
| #include "b_BasicEm/Functions.h" |
| #include "b_BasicEm/Memory.h" |
| |
| /* ------------------------------------------------------------------------- */ |
| |
| /* ========================================================================= */ |
| /* */ |
| /* ---- \ghd{ auxiliary functions } ---------------------------------------- */ |
| /* */ |
| /* ========================================================================= */ |
| |
| /* ------------------------------------------------------------------------- */ |
| |
| void bts_Int32Mat_reduceToNBits( int32* ptrA, uint32 sizeA, int32* bbpPtrA, uint32 nBitsA ) |
| { |
| int32 shiftL; |
| |
| /* find max element */ |
| int32 maxL = 0; |
| int32* ptrL = ptrA; |
| int32 iL = sizeA; |
| while( iL-- ) |
| { |
| int32 xL = *ptrL++; |
| if( xL < 0 ) xL = -xL; |
| if( xL > maxL ) maxL = xL; |
| } |
| |
| /* determine shift */ |
| shiftL = bts_absIntLog2( maxL ) + 1 - nBitsA; |
| |
| if( shiftL > 0 ) |
| { |
| ptrL = ptrA; |
| iL = sizeA; |
| while( iL-- ) |
| { |
| *ptrL = ( ( *ptrL >> ( shiftL - 1 ) ) + 1 ) >> 1; |
| ptrL++; |
| } |
| |
| *bbpPtrA -= shiftL; |
| } |
| } |
| |
| /* ------------------------------------------------------------------------- */ |
| |
| /* ========================================================================= */ |
| /* */ |
| /* ---- \ghd{ constructor / destructor } ----------------------------------- */ |
| /* */ |
| /* ========================================================================= */ |
| |
| /* ------------------------------------------------------------------------- */ |
| |
| void bts_Int32Mat_init( struct bbs_Context* cpA, |
| struct bts_Int32Mat* ptrA ) |
| { |
| ptrA->widthE = 0; |
| bbs_Int32Arr_init( cpA, &ptrA->arrE ); |
| } |
| |
| /* ------------------------------------------------------------------------- */ |
| |
| void bts_Int32Mat_exit( struct bbs_Context* cpA, |
| struct bts_Int32Mat* ptrA ) |
| { |
| ptrA->widthE = 0; |
| bbs_Int32Arr_exit( cpA, &ptrA->arrE ); |
| } |
| /* ------------------------------------------------------------------------- */ |
| |
| /* ========================================================================= */ |
| /* */ |
| /* ---- \ghd{ operators } -------------------------------------------------- */ |
| /* */ |
| /* ========================================================================= */ |
| |
| /* ------------------------------------------------------------------------- */ |
| |
| /* ========================================================================= */ |
| /* */ |
| /* ---- \ghd{ query functions } -------------------------------------------- */ |
| /* */ |
| /* ========================================================================= */ |
| |
| /* ------------------------------------------------------------------------- */ |
| |
| /* ========================================================================= */ |
| /* */ |
| /* ---- \ghd{ modify functions } ------------------------------------------- */ |
| /* */ |
| /* ========================================================================= */ |
| |
| /* ------------------------------------------------------------------------- */ |
| |
| void bts_Int32Mat_create( struct bbs_Context* cpA, |
| struct bts_Int32Mat* ptrA, |
| int32 widthA, |
| struct bbs_MemSeg* mspA ) |
| { |
| if( bbs_Context_error( cpA ) ) return; |
| bbs_Int32Arr_create( cpA, &ptrA->arrE, widthA * widthA, mspA ); |
| ptrA->widthE = widthA; |
| } |
| |
| /* ------------------------------------------------------------------------- */ |
| |
| void bts_Int32Mat_copy( struct bbs_Context* cpA, |
| struct bts_Int32Mat* ptrA, |
| const struct bts_Int32Mat* srcPtrA ) |
| { |
| if( ptrA->widthE != srcPtrA->widthE ) |
| { |
| bbs_ERROR0( "void bts_Int32Mat_copy( struct bts_Int32Mat* ptrA, struct bts_Int32Mat* srcPtrA ):\n" |
| "size mismatch" ); |
| return; |
| } |
| |
| bbs_Int32Arr_copy( cpA, &ptrA->arrE, &srcPtrA->arrE ); |
| } |
| |
| /* ------------------------------------------------------------------------- */ |
| |
| /* ========================================================================= */ |
| /* */ |
| /* ---- \ghd{ I/O } -------------------------------------------------------- */ |
| /* */ |
| /* ========================================================================= */ |
| |
| /* ------------------------------------------------------------------------- */ |
| |
| uint32 bts_Int32Mat_memSize( struct bbs_Context* cpA, |
| const struct bts_Int32Mat *ptrA ) |
| { |
| return bbs_SIZEOF16( uint32 ) |
| + bbs_SIZEOF16( uint32 ) /* version */ |
| + bbs_SIZEOF16( ptrA->widthE ) |
| + bbs_Int32Arr_memSize( cpA, &ptrA->arrE ); |
| } |
| |
| /* ------------------------------------------------------------------------- */ |
| |
| uint32 bts_Int32Mat_memWrite( struct bbs_Context* cpA, |
| const struct bts_Int32Mat* ptrA, |
| uint16* memPtrA ) |
| { |
| uint32 memSizeL = bts_Int32Mat_memSize( cpA, ptrA ); |
| memPtrA += bbs_memWrite32( &memSizeL, memPtrA ); |
| memPtrA += bbs_memWriteUInt32( bts_INT32MAT_VERSION, memPtrA ); |
| memPtrA += bbs_memWrite32( &ptrA->widthE, memPtrA ); |
| memPtrA += bbs_Int32Arr_memWrite( cpA, &ptrA->arrE, memPtrA ); |
| return memSizeL; |
| } |
| |
| /* ------------------------------------------------------------------------- */ |
| |
| uint32 bts_Int32Mat_memRead( struct bbs_Context* cpA, |
| struct bts_Int32Mat* ptrA, |
| const uint16* memPtrA, |
| struct bbs_MemSeg* mspA ) |
| { |
| uint32 memSizeL, versionL; |
| if( bbs_Context_error( cpA ) ) return 0; |
| memPtrA += bbs_memRead32( &memSizeL, memPtrA ); |
| memPtrA += bbs_memReadVersion32( cpA, &versionL, bts_INT32MAT_VERSION, memPtrA ); |
| memPtrA += bbs_memRead32( &ptrA->widthE, memPtrA ); |
| memPtrA += bbs_Int32Arr_memRead( cpA, &ptrA->arrE, memPtrA, mspA ); |
| |
| if( memSizeL != bts_Int32Mat_memSize( cpA, ptrA ) ) |
| { |
| bbs_ERR0( bbs_ERR_CORRUPT_DATA, "uint32 bts_Int32Mat_memRead( const struct bts_Int32Mat* ptrA, const void* memPtrA ):\n" |
| "size mismatch" ); |
| } |
| return memSizeL; |
| } |
| |
| /* ------------------------------------------------------------------------- */ |
| |
| /* ========================================================================= */ |
| /* */ |
| /* ---- \ghd{ exec functions } --------------------------------------------- */ |
| /* */ |
| /* ========================================================================= */ |
| |
| /* ------------------------------------------------------------------------- */ |
| |
| flag bts_Int32Mat_solve( struct bbs_Context* cpA, |
| const int32* matA, |
| int32 matWidthA, |
| const int32* inVecA, |
| int32* outVecA, |
| int32 bbpA, |
| int32* tmpMatA, |
| int32* tmpVecA ) |
| { |
| bbs_memcpy32( tmpMatA, matA, ( matWidthA * matWidthA ) * bbs_SIZEOF32( int32 ) ); |
| |
| return bts_Int32Mat_solve2( cpA, |
| tmpMatA, |
| matWidthA, |
| inVecA, |
| outVecA, |
| bbpA, |
| tmpVecA ); |
| } |
| |
| /* ------------------------------------------------------------------------- */ |
| |
| flag bts_Int32Mat_solve2( struct bbs_Context* cpA, |
| int32* matA, |
| int32 matWidthA, |
| const int32* inVecA, |
| int32* outVecA, |
| int32 bbpA, |
| int32* tmpVecA ) |
| { |
| int32 sizeL = matWidthA; |
| int32 bbpL = bbpA; |
| int32 iL, jL, kL; |
| int32 iPivL; |
| int32 jPivL; |
| |
| int32* vecL = outVecA; |
| int32* matL = matA; |
| int32* checkArrL = tmpVecA; |
| |
| for( iL = 0; iL < sizeL; iL++ ) |
| { |
| checkArrL[ iL ] = 0; |
| } |
| |
| bbs_memcpy32( outVecA, inVecA, sizeL * bbs_SIZEOF32( int32 ) ); |
| |
| iPivL = 0; |
| |
| for( kL = 0; kL < sizeL; kL++ ) |
| { |
| /* find pivot */ |
| int32 maxAbsL = 0; |
| int32* pivRowL; |
| |
| int32 bbp_pivRowL, bbp_vecL, shiftL; |
| |
| jPivL = -1; |
| for( iL = 0; iL < sizeL; iL++ ) |
| { |
| if( checkArrL[ iL ] != 1 ) |
| { |
| int32* rowL = matL + ( iL * sizeL ); |
| for( jL = 0; jL < sizeL; jL++ ) |
| { |
| if( checkArrL[ jL ] == 0 ) |
| { |
| int32 absElemL = rowL[ jL ]; |
| if( absElemL < 0 ) absElemL = -absElemL; |
| if( maxAbsL < absElemL ) |
| { |
| maxAbsL = absElemL; |
| iPivL = iL; |
| jPivL = jL; |
| } |
| } |
| else if( checkArrL[ jL ] > 1 ) |
| { |
| return FALSE; |
| } |
| } |
| } |
| } |
| |
| /* successfull ? */ |
| if( jPivL < 0 ) |
| { |
| return FALSE; |
| } |
| |
| checkArrL[ jPivL ]++; |
| |
| /* exchange rows to put pivot on diagonal, if neccessary */ |
| if( iPivL != jPivL ) |
| { |
| int32* row1PtrL = matL + ( iPivL * sizeL ); |
| int32* row2PtrL = matL + ( jPivL * sizeL ); |
| for( jL = 0; jL < sizeL; jL++ ) |
| { |
| int32 tmpL = *row1PtrL; |
| *row1PtrL++ = *row2PtrL; |
| *row2PtrL++ = tmpL; |
| } |
| |
| { |
| int32 tmpL = vecL[ jPivL ]; |
| vecL[ jPivL ] = vecL[ iPivL ]; |
| vecL[ iPivL ] = tmpL; |
| } |
| } |
| /* now index jPivL specifies pivot row and maximum element */ |
| |
| |
| /** Overflow protection: only if the highest bit of the largest matrix element is set, |
| * we need to shift the whole matrix and the right side vector 1 bit to the right, |
| * to make sure there can be no overflow when the pivot row gets subtracted from the |
| * other rows. |
| * Getting that close to overflow is a rare event, so this shift will happen only |
| * occasionally, or not at all. |
| */ |
| if( maxAbsL & 1073741824 ) /*( 1 << 30 )*/ |
| { |
| /* right shift matrix by 1 */ |
| int32 iL = sizeL * sizeL; |
| int32* ptrL = matL; |
| while( iL-- ) |
| { |
| *ptrL = ( *ptrL + 1 ) >> 1; |
| ptrL++; |
| } |
| |
| /* right shift right side vector by 1 */ |
| iL = sizeL; |
| ptrL = vecL; |
| while( iL-- ) |
| { |
| *ptrL = ( *ptrL + 1 ) >> 1; |
| ptrL++; |
| } |
| |
| /* decrement bbpL */ |
| bbpL--; |
| } |
| |
| |
| /* reduce elements of pivot row to 15 bit */ |
| pivRowL = matL + jPivL * sizeL; |
| bbp_pivRowL = bbpL; |
| bts_Int32Mat_reduceToNBits( pivRowL, sizeL, &bbp_pivRowL, 15 ); |
| |
| /* scale pivot row such that maximum equals 1 */ |
| { |
| int32 maxL = pivRowL[ jPivL ]; |
| int32 bbp_maxL = bbp_pivRowL; |
| int32 factorL = 1073741824 / maxL; /*( 1 << 30 )*/ |
| |
| for( jL = 0; jL < sizeL; jL++ ) |
| { |
| pivRowL[ jL ] = ( pivRowL[ jL ] * factorL + ( 1 << 14 ) ) >> 15; |
| } |
| bbp_pivRowL = 15; |
| |
| /* set to 1 to avoid computational errors */ |
| pivRowL[ jPivL ] = ( int32 )1 << bbp_pivRowL; |
| |
| shiftL = 30 - bts_absIntLog2( vecL[ jPivL ] ); |
| |
| vecL[ jPivL ] = ( vecL[ jPivL ] << shiftL ) / maxL; |
| bbp_vecL = bbpL + shiftL - bbp_maxL; |
| |
| bbs_int32ReduceToNBits( &( vecL[ jPivL ] ), &bbp_vecL, 15 ); |
| } |
| |
| /* subtract pivot row from all other rows */ |
| for( iL = 0; iL < sizeL; iL++ ) |
| { |
| if( iL != jPivL ) |
| { |
| int32* rowPtrL = matL + iL * sizeL; |
| |
| int32 tmpL = *( rowPtrL + jPivL ); |
| int32 bbp_tmpL = bbpL; |
| bbs_int32ReduceToNBits( &tmpL, &bbp_tmpL, 15 ); |
| |
| shiftL = bbp_tmpL + bbp_pivRowL - bbpL; |
| if( shiftL > 0 ) |
| { |
| for( jL = 0; jL < sizeL; jL++ ) |
| { |
| *rowPtrL++ -= ( ( ( tmpL * pivRowL[ jL ] ) >> ( shiftL - 1 ) ) + 1 ) >> 1; |
| } |
| } |
| else |
| { |
| for( jL = 0; jL < sizeL; jL++ ) |
| { |
| *rowPtrL++ -= ( tmpL * pivRowL[ jL ] ) << -shiftL; |
| } |
| } |
| |
| shiftL = bbp_tmpL + bbp_vecL - bbpL; |
| if( shiftL > 0 ) |
| { |
| vecL[ iL ] -= ( ( ( tmpL * vecL[ jPivL ] ) >> ( shiftL - 1 ) ) + 1 ) >> 1; |
| } |
| else |
| { |
| vecL[ iL ] -= ( tmpL * vecL[ jPivL ] ) << -shiftL; |
| } |
| } |
| } |
| |
| /* change bbp of pivot row back to bbpL */ |
| shiftL = bbpL - bbp_pivRowL; |
| if( shiftL >= 0 ) |
| { |
| for( jL = 0; jL < sizeL; jL++ ) |
| { |
| pivRowL[ jL ] <<= shiftL; |
| } |
| } |
| else |
| { |
| shiftL = -shiftL; |
| for( jL = 0; jL < sizeL; jL++ ) |
| { |
| pivRowL[ jL ] = ( ( pivRowL[ jL ] >> ( shiftL - 1 ) ) + 1 ) >> 1; |
| } |
| } |
| |
| shiftL = bbpL - bbp_vecL; |
| if( shiftL >= 0 ) |
| { |
| vecL[ jPivL ] <<= shiftL; |
| } |
| else |
| { |
| shiftL = -shiftL; |
| vecL[ jPivL ] = ( ( vecL[ jPivL ] >> ( shiftL - 1 ) ) + 1 ) >> 1; |
| } |
| /* |
| if( sizeL <= 5 ) bts_Int32Mat_print( matL, vecL, sizeL, bbpL ); |
| */ |
| } /* of kL */ |
| |
| /* in case bbpL has been decreased by the overflow protection, change it back now */ |
| if( bbpA > bbpL ) |
| { |
| /* find largest element of solution vector */ |
| int32 maxL = 0; |
| int32 iL, shiftL; |
| for( iL = 0; iL < sizeL; iL++ ) |
| { |
| int32 xL = vecL[ iL ]; |
| if( xL < 0 ) xL = -xL; |
| if( xL > maxL ) maxL = xL; |
| } |
| |
| /* check whether we can left shift without overflow */ |
| shiftL = 30 - bts_absIntLog2( maxL ); |
| if( shiftL < ( bbpA - bbpL ) ) |
| { |
| /* |
| bbs_WARNING1( "flag bts_Int32Mat_solve2( ... ): getting overflow when trying to " |
| "compute solution vector with bbp = %d. Choose smaller bbp.\n", bbpA ); |
| */ |
| |
| return FALSE; |
| } |
| |
| /* shift left */ |
| shiftL = bbpA - bbpL; |
| for( iL = 0; iL < sizeL; iL++ ) vecL[ iL ] <<= shiftL; |
| } |
| |
| return TRUE; |
| } |
| |
| /* ------------------------------------------------------------------------- */ |
| |
| /* ========================================================================= */ |
| |