3 Dimensional Matrix Multiplication

Open discussion for MAGMA library (Matrix Algebra on GPU and Multicore Architectures)
Volodimir
Posts: 10
Joined: Fri Jun 29, 2018 2:52 pm

3 Dimensional Matrix Multiplication

Post by Volodimir » Mon Aug 06, 2018 6:52 pm

I have a three-dimensional array S with dimensions [A, R, M], and a two-dimensional kernel K with dimensions [M, M].
I need to generate the output array Q with dimensions [A, R], where each element is Q(i,j) = S(i,j,:) * K * S(i,j,:)^H (^H denotes the Hermitian conjugate).
I can do this by brute force: run gemm on the flattened S' with dimensions [A*R, M], and then extract the diagonal elements from the resulting [A*R, A*R] array, but this is not efficient.
Any idea how this could be done better?
I understand this question is probably more appropriate for NVIDIA or Stack Overflow people, but I wanted to start here.
Thanks!

mgates3
Posts: 897
Joined: Fri Jan 06, 2012 2:13 pm

Re: 3 Dimensional Matrix Multiplication

Post by mgates3 » Mon Aug 06, 2018 11:54 pm

For concreteness, does the Matlab code below do what you want? If so, it can be done with a gemm followed by batched dot products. There may be something in MAGMA that can help with the batched dot products.

Code:

% arbitrary dimensions
a = 2;
r = 3;
m = 4;

S = rand( a, r, m );
K = rand( m, m );

% naive implementation
Q = zeros( a, r );
for i = 1:a
	for j = 1:r
		x = reshape( S(i,j,:), [m,1] );
		Q(i,j) = x' * K * x;
	end
end

% gemm + dot products
Q2 = zeros( a*r, 1 );
S2 = reshape( S, [a*r, m] );
tmp = S2*K;  % gemm
for i = 1:a*r
	Q2(i) = tmp(i,:) * S2(i,:)';
end
Q2 = reshape( Q2, [a, r] );
-mark
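
The code above uses real random data. For complex S and K, which the Hermitian conjugate in the original question suggests, the transposes need a little care. Here is a minimal sketch, not from the original post, that assumes the row-vector convention Q(i,j) = s * K * s^H with s = S(i,j,:) and checks that the naive loop and the gemm + dot products version agree. The names Q_naive and Q_gemm are only illustrative.

```matlab
% Sketch for complex data (assumption: Q(i,j) = s * K * s^H, s = S(i,j,:) as a row)
a = 2; r = 3; m = 4;
S = rand( a, r, m ) + 1i*rand( a, r, m );
K = rand( m, m )    + 1i*rand( m, m );

% naive implementation
Q_naive = zeros( a, r );
for i = 1:a
	for j = 1:r
		x = reshape( S(i,j,:), [m,1] );      % column vector, not conjugated
		Q_naive(i,j) = x.' * K * conj( x );  % s * K * s^H with s = x.'
	end
end

% gemm + dot products
S2  = reshape( S, [a*r, m] );
tmp = S2*K;                               % gemm
Q_gemm = zeros( a*r, 1 );
for i = 1:a*r
	Q_gemm(i) = tmp(i,:) * S2(i,:)';      % ' is the conjugate transpose, giving s^H
end
Q_gemm = reshape( Q_gemm, [a, r] );

% the two results should agree up to rounding
assert( norm( Q_naive - Q_gemm, 'fro' ) < 1e-12 * norm( Q_naive, 'fro' ) );
```

Note that writing x' * K * x on complex data would compute conj(s) * K * s.' instead, which in general is a different number unless K is symmetric.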

Volodimir
Posts: 10
Joined: Fri Jun 29, 2018 2:52 pm

Re: 3 Dimensional Matrix Multiplication

Post by Volodimir » Tue Aug 14, 2018 12:16 am

Thanks a lot!
A couple of questions.
First, on the gemm side:

Code:

tmp = S2*K;  % gemm
Aren't we doing a lot of unnecessary processing here?
Second, for the batch dot product, forgive my ignorance, but I do not see how to plug in batch processing instead of the loop:

Code:

for i = 1:a*r
   Q2(i) = tmp(i,:) * S2(i,:)';
end

mgates3
Posts: 897
Joined: Fri Jan 06, 2012 2:13 pm

Re: 3 Dimensional Matrix Multiplication

Post by mgates3 » Tue Aug 14, 2018 1:20 am

I think both the naive and the gemm implementation do the same computations, taking 2AR(M^2) + ARM - AR flops.

Is that the operation you are trying to do?

If you multiply (tmp * S2') instead of using the loop of dot products, you would be doing a lot of extra work: you only need the diagonal entries, as you said.
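
The counts above can be checked with a few lines of arithmetic. This is a rough sketch that counts one flop per multiply and one per add, using the small dimensions from the Matlab example earlier in the thread; the variable names are only illustrative.

```matlab
a = 2; r = 3; m = 4;
n = a*r;                            % number of rows of S2

gemm_flops = n*m*(2*m - 1);         % tmp = S2*K: n*m entries, each m multiplies + (m-1) adds
dot_flops  = n*(2*m - 1);           % n dot products of length m
total      = gemm_flops + dot_flops;
assert( total == 2*a*r*m^2 + a*r*m - a*r );   % 210 for these sizes

full_flops = n*n*(2*m - 1);         % tmp * S2' computes all n*n entries (252 here)
ratio      = full_flops / dot_flops;
assert( ratio == n );               % the full product costs a*r times the dots
```

So for that last step, forming the whole product and keeping only its diagonal costs a*r times as much as the row-wise dot products.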

A batch dot would, for instance, take the two arrays tmp and S2 and compute the dot products of all their corresponding rows (or columns), so a single function call would replace that loop over dots. However, after investigating further, it doesn't appear that MAGMA currently has the batch dot that you would need. It shouldn't be a hard kernel to write, though.
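
To make the semantics concrete, here is a small Matlab sketch of what a row-wise batch dot computes. It only illustrates the operation and is not a MAGMA routine; the complex test data and the names Q_loop and Q_batch are assumptions for the example. Each row's dot product is independent of the others, which is what would let a GPU kernel handle them in parallel.

```matlab
a = 2; r = 3; m = 4;
S2  = rand( a*r, m ) + 1i*rand( a*r, m );   % flattened S, complex test data
K   = rand( m, m )   + 1i*rand( m, m );
tmp = S2*K;                                  % gemm

% loop of dot products, as in the code earlier in the thread
Q_loop = zeros( a*r, 1 );
for i = 1:a*r
	Q_loop(i) = tmp(i,:) * S2(i,:)';         % ' is the conjugate transpose
end

% the same thing as one "batched" operation:
% multiply corresponding entries, then sum along each row
Q_batch = sum( tmp .* conj( S2 ), 2 );

assert( norm( Q_loop - Q_batch ) < 1e-12 * norm( Q_loop ) );
Q = reshape( Q_batch, [a, r] );
```

For real data the conj has no effect, and the expression reduces to sum( tmp .* S2, 2 ).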

Volodimir
Posts: 10
Joined: Fri Jun 29, 2018 2:52 pm

Re: 3 Dimensional Matrix Multiplication

Post by Volodimir » Tue Aug 14, 2018 1:52 am

Sorry, I do not see an "obvious" way to write a kernel :(

mgates3
Posts: 897
Joined: Fri Jan 06, 2012 2:13 pm

Re: 3 Dimensional Matrix Multiplication

Post by mgates3 » Tue Aug 14, 2018 11:34 am

Here's a simple implementation. It is only very lightly tested and probably needs modifications for your specific environment and purposes, but it should be a good starting point.
-mark
Attachments
batch_dot.tar.gz
(1.46 KiB) Downloaded 157 times
