Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def _cuda_get_module(self):
    """Return the CUDA module built from ``CUDA_SRU_PTX``, loading it once.

    The loaded module is cached on the class attribute
    ``SRUFunction._cuda_module``; subsequent calls return the cached object
    without reloading.

    Returns:
        The loaded ``function.Module`` instance.

    Raises:
        NotImplementedError: If ``cupy_version`` is neither 1 nor 2.
    """
    cached = SRUFunction._cuda_module
    if cached is not None:
        return cached

    # Build into a local first: the class-level cache must only ever hold a
    # fully loaded module. Publishing it before loading would leave an empty
    # module cached if loading fails or the version is unsupported, and every
    # later call would silently return that unusable object.
    module = function.Module()
    if cupy_version == 1:
        # Version 1 path: the PTX is handed to the module directly.
        module.load(CUDA_SRU_PTX)
    elif cupy_version == 2:
        # Version 2 path: the PTX goes through a link step and the linked
        # result is what gets loaded.
        link_state = function.LinkState()
        link_state.add_ptr_data(CUDA_SRU_PTX, u"cupy.ptx")
        module.load(link_state.complete())
    else:
        raise NotImplementedError()

    SRUFunction._cuda_module = module
    return module
def sru(x, W, B, initial_ct, use_tanh=True, mask_x=None):
    """Run a newly constructed ``SRUFunction`` on the given inputs.

    Args:
        x, W, B, initial_ct: Forwarded positionally, in this order, to the
            ``SRUFunction`` call. (Presumably the input, weight, bias and
            initial cell state -- confirm against ``SRUFunction``.)
        use_tanh: Passed to the ``SRUFunction`` constructor.
        mask_x: Optional extra input. When not ``None`` it is appended as a
            fifth positional argument; otherwise it is omitted entirely.

    Returns:
        Whatever the ``SRUFunction`` call returns.
    """
    inputs = (x, W, B, initial_ct)
    if mask_x is not None:
        # Only pass the mask when one was supplied, so the function sees
        # four inputs rather than a trailing None.
        inputs = inputs + (mask_x,)
    return SRUFunction(use_tanh)(*inputs)