Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def benchmark_pytorch_sru(batchsize, seq_length, feature_dimension, repeat=50, warmup=1):
    """Time the forward and forward+backward passes of an ``SRUCell`` on a GPU.

    Builds ``SRUCell(feature_dimension, feature_dimension)`` on the CUDA device
    selected by the module-level ``args.gpu_device`` and feeds it a random
    input created as ``torch.randn(seq_length, batchsize, feature_dimension)``.

    Args:
        batchsize: Second dimension of the random input.
        seq_length: First dimension of the random input.
        feature_dimension: Last dimension of the input; also used for both
            size arguments of ``SRUCell``.
        repeat: Number of timed iterations averaged per measurement; must be
            at least 1.
        warmup: Untimed forward+backward iterations run before measuring, so
            one-off setup costs (e.g. lazy CUDA/kernel initialization) are not
            included in the means.

    Computes ``forward_time_mean`` (mean wall-clock seconds per forward call)
    and ``backward_time_mean`` (mean seconds per forward call followed by
    ``backward()`` on the summed output).

    Raises:
        ValueError: If ``repeat`` is less than 1 (would otherwise divide by 0).
    """
    if repeat < 1:
        raise ValueError("repeat must be >= 1, got {}".format(repeat))
    with torch.cuda.device(args.gpu_device):
        layer = SRUCell(feature_dimension, feature_dimension)
        layer.cuda()
        x_data = torch.autograd.Variable(
            torch.randn(seq_length, batchsize, feature_dimension).cuda())

        # Warm-up: keep one-off setup costs out of the timed loops.
        for _ in range(warmup):
            output, hidden = layer(x_data, None)
            torch.sum(output).backward()

        # forward
        # CUDA work is queued asynchronously; synchronize around the timed
        # region so the clock measures execution, not just kernel launches.
        torch.cuda.synchronize()
        start_time = time.time()
        for _ in range(repeat):
            output, hidden = layer(x_data, None)
        torch.cuda.synchronize()
        forward_time_mean = (time.time() - start_time) / repeat

        # backward (each iteration re-runs forward, then backpropagates).
        # The device is already idle after the synchronize above.
        start_time = time.time()
        for _ in range(repeat):
            output, hidden = layer(x_data, None)
            torch.sum(output).backward()
        torch.cuda.synchronize()
        backward_time_mean = (time.time() - start_time) / repeat