# Secure your code as it's written. Use Snyk Code to scan source code in minutes (no build needed) and fix issues immediately.
# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2019-12-31 00:22
import types
from typing import Callable, List, Generator, Union, Any, Tuple, Iterable
from hanlp.components.lambda_wrapper import LambdaComponent
from hanlp.common.component import Component
from hanlp.common.document import Document
from hanlp.utils.component_util import load_from_meta
from hanlp.utils.io_util import save_json, load_json
from hanlp.utils.reflection import module_path_of, str_to_type, class_path_of
import hanlp
class Pipe(Component):
def __init__(self, component: Component, input_key: str = None, output_key: str = None, **kwargs) -> None:
    """Wrap ``component`` together with the keys and extra arguments it is used with.

    Parameters
    ----------
    component
        The component wrapped by this pipe; its ``meta`` is embedded in this pipe's ``meta``.
    input_key
        Stored as ``self.input_key`` and recorded in ``meta``.
        Presumably the document key this pipe reads from - confirm against ``predict``.
    output_key
        Stored as ``self.output_key`` and recorded in ``meta``.
        Presumably the document key this pipe writes to - confirm against ``predict``.
    kwargs
        Extra keyword arguments kept in ``self.kwargs`` and recorded in ``meta``.
    """
    super().__init__()
    self.output_key = output_key
    self.input_key = input_key
    self.component = component
    self.kwargs = kwargs
    # Record everything ``from_meta`` reads back: component meta, both keys and kwargs.
    self.meta.update({
        'component': component.meta,
        'input_key': self.input_key,
        'output_key': self.output_key,
        'kwargs': self.kwargs
    })
# noinspection PyShadowingBuiltins
def predict(self, doc: Document, **kwargs) -> Document:
Parameters
----------
meta
kwargs
Returns
-------
Component
"""
cls = meta.get('class_path', None)
assert cls, f'{meta} doesn\'t contain class_path field'
cls = str_to_type(cls)
return cls.from_meta(meta)
class KerasComponent(Component, ABC):
def __init__(self, transform: Transform) -> None:
    """Create the component with no model yet and a config shared with ``transform``.

    Parameters
    ----------
    transform
        The transform used by this component. Any entries already present in
        ``transform.config`` are copied into this component's config, after which
        ``transform.config`` is rebound to that very same config object.
    """
    super().__init__()
    self.model: Optional[tf.keras.Model] = None
    self.config = SerializableDict()
    self.transform = transform
    # share config with transform for convenience, so we don't need to pass args around
    if self.transform.config:
        for k, v in self.transform.config.items():
            self.config[k] = v
    self.transform.config = self.config
def evaluate(self, input_path: str, save_dir=None, output=False, batch_size=128, logger: logging.Logger = None,
callbacks: List[tf.keras.callbacks.Callback] = None, warm_up=True, verbose=True, **kwargs):
input_path = get_resource(input_path)
file_prefix, ext = os.path.splitext(input_path)
name = os.path.basename(file_prefix)
def insert(self, index: int, component: Callable, input_key: Union[str, Iterable[str]] = None,
           output_key: Union[str, Iterable[str]] = None,
           **kwargs):
    """Wrap ``component`` in a ``Pipe`` and insert it at position ``index``.

    Parameters
    ----------
    index
        Position at which the new pipe is inserted.
    component
        A ``Component`` instance, or any other callable, which is wrapped in a
        ``LambdaComponent`` first.
    input_key
        Input key of the new pipe. When falsy and the pipeline is not empty, it
        defaults to the ``output_key`` of the last pipe currently in the pipeline.
    output_key
        Output key of the new pipe.
    kwargs
        Extra keyword arguments forwarded to the ``Pipe`` constructor.

    Returns
    -------
    The pipeline itself, so calls can be chained.
    """
    if not input_key and len(self):
        # NOTE(review): the default comes from the last pipe, not from the pipe
        # preceding ``index`` - confirm this is intended for non-append inserts.
        input_key = self[-1].output_key
    if not isinstance(component, Component):
        component = LambdaComponent(component)
    super().insert(index, Pipe(component, input_key, output_key, **kwargs))
    return self
else:
doc[self.output_key] = output
return doc
return output
def __repr__(self):
    """Return ``'<input_key>-><component class name>-><output_key>'`` for this pipe."""
    return f'{self.input_key}->{self.component.__class__.__name__}->{self.output_key}'
@staticmethod
def from_meta(meta: dict, **kwargs):
    """Rebuild a pipe from its ``meta`` dictionary.

    Parameters
    ----------
    meta
        Must contain ``'class_path'``, ``'component'``, ``'input_key'``,
        ``'output_key'`` and ``'kwargs'``.
    kwargs
        Accepted but not used by this method.

    Returns
    -------
    An instance of the class named by ``meta['class_path']``, constructed from the
    component loaded out of ``meta['component']`` plus the stored keys and kwargs.
    """
    cls = str_to_type(meta['class_path'])
    component = load_from_meta(meta['component'])
    return cls(component, meta['input_key'], meta['output_key'], **meta['kwargs'])
class Pipeline(Component, list):
def __init__(self, *pipes: Pipe) -> None:
    """Create a pipeline, optionally pre-populated with ``pipes``.

    Parameters
    ----------
    pipes
        Pipes added to the pipeline as-is, in the order given.
    """
    super().__init__()
    if pipes:
        # ``extend`` adds the given objects directly; they are not re-wrapped.
        self.extend(pipes)
def append(self, component: Callable, input_key: Union[str, Iterable[str]] = None,
           output_key: Union[str, Iterable[str]] = None, **kwargs):
    """Add ``component`` at the end of the pipeline.

    Delegates to ``self.insert`` with the current length as the index.

    Parameters
    ----------
    component
        The component or callable to add.
    input_key
        Passed through to ``insert``.
    output_key
        Passed through to ``insert``.
    kwargs
        Extra keyword arguments passed through to ``insert``.

    Returns
    -------
    The pipeline itself, so calls can be chained.
    """
    self.insert(len(self), component, input_key, output_key, **kwargs)
    return self
def insert(self, index: int, component: Callable, input_key: Union[str, Iterable[str]] = None,
output_key: Union[str, Iterable[str]] = None,
**kwargs):
if not input_key and len(self):
input_key = self[-1].output_key
if not isinstance(component, Component):
# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2019-12-31 18:36
from typing import Callable, Any
from hanlp.common.component import Component
from hanlp.utils.reflection import class_path_of, object_from_class_path, str_to_type
class LambdaComponent(Component):
def __init__(self, function: Callable) -> None:
    """Wrap a plain callable as a component.

    Parameters
    ----------
    function
        The callable invoked by ``predict``. The result of ``class_path_of(function)``
        is stored in ``meta['function']``.
    """
    super().__init__()
    self.function = function
    self.meta['function'] = class_path_of(function)
def predict(self, data: Any, **kwargs):
    """Call the wrapped function on ``data``.

    Parameters
    ----------
    data
        The input. Passed as a single positional argument, or spread into several
        positional arguments when ``_hanlp_unpack`` is truthy.
    kwargs
        Forwarded to the wrapped function, except for the control flag
        ``_hanlp_unpack``, which is removed first.

    Returns
    -------
    Whatever the wrapped function returns.
    """
    # Pop the flag so it never reaches the wrapped function.
    unpack = kwargs.pop('_hanlp_unpack', None)
    if unpack:
        return self.function(*data, **kwargs)
    return self.function(data, **kwargs)
@staticmethod
def from_meta(meta: dict, **kwargs):
cls = str_to_type(meta['class_path'])
function = meta['function']
function = object_from_class_path(function)