Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
try:
from contextvars import ContextVar
except ImportError:
pass
else:
import typing # pylint: disable=unused-import
from . import base_context
class AsyncRuntimeContext(base_context.BaseRuntimeContext):
class Slot(base_context.BaseRuntimeContext.Slot):
def __init__(self, name: str, default: object):
    """Set up a named slot whose value lives in its own ``ContextVar``.

    Args:
        name: Stored as ``self.name`` and also used as the name of the
            backing ``ContextVar``.
        default: Passed through ``base_context.wrap_callable``; the result
            is kept as ``self.default`` and is invoked with no arguments
            by ``clear`` and ``get``.
            NOTE(review): presumably ``wrap_callable`` turns a plain value
            into a zero-argument factory -- confirm in ``base_context``.
    """
    # pylint: disable=super-init-not-called
    self.name = name
    backing_var = ContextVar(name)  # type: ContextVar[object]
    self.contextvar = backing_var
    default_factory = base_context.wrap_callable(default)
    self.default = default_factory  # type: typing.Callable[..., object]
def clear(self) -> None:
    """Reset the slot by storing a freshly produced default value.

    Calls ``self.default()`` and writes the result into the backing
    ``ContextVar``; nothing is returned.
    """
    # Look up the setter first so evaluation order matches a direct
    # ``self.contextvar.set(self.default())`` call.
    store = self.contextvar.set
    store(self.default())
def get(self) -> object:
try:
return self.contextvar.get()
except LookupError:
value = self.default()
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import threading
import typing # pylint: disable=unused-import
from . import base_context
class ThreadLocalRuntimeContext(base_context.BaseRuntimeContext):
class Slot(base_context.BaseRuntimeContext.Slot):
_thread_local = threading.local()
def __init__(self, name: str, default: "object"):
    """Set up a named slot stored on the class-level thread-local object.

    Args:
        name: Stored as ``self.name``; ``clear`` and ``get`` use it as the
            attribute name on ``_thread_local``.
        default: Passed through ``base_context.wrap_callable``; the result
            is kept as ``self.default`` and is invoked with no arguments
            by ``clear``.
            NOTE(review): presumably ``wrap_callable`` turns a plain value
            into a zero-argument factory -- confirm in ``base_context``.
    """
    # pylint: disable=super-init-not-called
    self.name = name
    default_factory = base_context.wrap_callable(default)
    self.default = default_factory  # type: typing.Callable[..., object]
def clear(self) -> None:
    """Reset the slot by storing a freshly produced default value.

    Calls ``self.default()`` and assigns the result to the attribute named
    ``self.name`` on ``_thread_local``; nothing is returned.
    """
    # Resolve the store and key before calling the factory so evaluation
    # order matches a direct ``setattr(..., self.default())`` call.
    local_store = self._thread_local
    attr_name = self.name
    setattr(local_store, attr_name, self.default())
def get(self) -> "object":
try:
got = getattr(self._thread_local, self.name) # type: object
return got