Skip to content

Commit c406bf0

Browse files
authored
Merge branch 'main' into dev/add_duckdb
Signed-off-by: Bibo Hao <haobibo@users.noreply.github.com>
2 parents c76875b + 83c6693 commit c406bf0

18 files changed

Lines changed: 369 additions & 44 deletions

File tree

.github/workflows/pip.yml

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
# Secret Variables required in GitHub secrets: TWINE_USERNAME, TWINE_PASSWORD / TWINE_USERNAME_TEST, TWINE_PASSWORD_TEST
22

3-
name: build
3+
name: build-pip-publish
44

5-
# Controls when the action will run.
65
on:
7-
# Triggers the workflow on push or pull request events but only for the main branch
86
push:
97
branches: [ main ]
108
paths-ignore: [ "*.md" ]
@@ -26,7 +24,7 @@ jobs:
2624
steps:
2725
# Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it
2826
# sudo python setup.py install clean --all
29-
- uses: actions/checkout@v3
27+
- uses: actions/checkout@v4
3028

3129
- name: pip-install-test
3230
run: |
@@ -48,7 +46,7 @@ jobs:
4846
sudo python3 -c "import fcntl; fcntl.fcntl(1, fcntl.F_SETFL, 0)"
4947
sudo python3 setup.py sdist bdist_wheel
5048
ls -alh ./dist
51-
if [ "${GITHUB_REPOSITORY}" = "QPod/aloha" ] && [ "${GITHUB_REF_NAME}" = "main" ] ; then
49+
if [ "${GITHUB_REPOSITORY}" = "QPod/aloha-python" ] && [ "${GITHUB_REF_NAME}" = "main" ] ; then
5250
twine upload dist/* --verbose -u "${TWINE_USERNAME}" -p "${TWINE_PASSWORD}" ;
5351
elif [ ! -z "${TWINE_USERNAME_TEST}" ]; then
5452
twine upload dist/* --verbose -u "${TWINE_USERNAME_TEST}" -p "${TWINE_PASSWORD_TEST}" \

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Aloha!
22

33
[![License](https://img.shields.io/github/license/QPod/aloha)](https://github.com/QPod/aloha/blob/main/LICENSE)
4-
[![GitHub Workflow Status](https://img.shields.io/github/workflow/status/QPod/aloha/build)](https://github.com/QPod/aloha/actions)
4+
[![GitHub Workflow Status](https://img.shields.io/github/actions/workflow/status/QPod/aloha-python/pip.yml?branch=main)](https://github.com/QPod/aloha-python/actions)
55
[![Join the Gitter Chat](https://img.shields.io/gitter/room/nwjs/nw.js.svg)](https://gitter.im/QPod/)
66
[![PyPI version](https://img.shields.io/pypi/v/aloha)](https://pypi.python.org/pypi/aloha/)
77
[![PyPI Downloads](https://img.shields.io/pypi/dm/aloha)](https://pepy.tech/badge/aloha/)
@@ -21,6 +21,6 @@ Please generously STAR★ our project or donate to us! [![GitHub Starts](https:
2121

2222
## Getting started
2323

24-
```py
24+
```shell
2525
pip install aloha[all]
2626
```

demo/app_common/ainlp/__init__.py

Whitespace-only changes.
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
from typing import List
2+
3+
import torch
4+
from transformers import AutoTokenizer, AutoModel
5+
6+
from aloha.service.streamer import ManagedModel
7+
8+
SEED = 0
9+
torch.manual_seed(SEED)
10+
torch.cuda.manual_seed(SEED)
11+
12+
13+
class TextUnmaskModel:
14+
def __init__(self, max_sent_len=16, model_path="bert-base-uncased"):
15+
self.model_path = model_path
16+
self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
17+
self.transformer = AutoModel.from_pretrained(self.model_path)
18+
self.transformer.eval()
19+
self.transformer.to(device="cuda")
20+
self.max_sent_len = max_sent_len
21+
22+
def predict(self, batch: List[str]) -> List[str]:
23+
"""predict masked word"""
24+
batch_inputs = []
25+
masked_indexes = []
26+
27+
for text in batch:
28+
tokenized_text = self.tokenizer.tokenize(text)
29+
if len(tokenized_text) > self.max_sent_len - 2:
30+
tokenized_text = tokenized_text[: self.max_sent_len - 2]
31+
32+
tokenized_text = ['[CLS]'] + tokenized_text + ['[SEP]']
33+
tokenized_text += ['[PAD]'] * (self.max_sent_len - len(tokenized_text))
34+
35+
indexed_tokens = self.tokenizer.convert_tokens_to_ids(tokenized_text)
36+
batch_inputs.append(indexed_tokens)
37+
masked_indexes.append(tokenized_text.index('[MASK]'))
38+
39+
tokens_tensor = torch.tensor(batch_inputs).to("cuda")
40+
41+
with torch.no_grad():
42+
# prediction_scores: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
43+
prediction_scores = self.transformer(tokens_tensor)[0]
44+
45+
batch_outputs = []
46+
for i in range(len(batch_inputs)):
47+
predicted_index = torch.argmax(prediction_scores[i, masked_indexes[i]]).item()
48+
predicted_token = self.tokenizer.convert_ids_to_tokens(predicted_index)
49+
batch_outputs.append(predicted_token)
50+
51+
return batch_outputs
52+
53+
54+
class ManagedBertModel(ManagedModel):
55+
def init_model(self):
56+
self.model = TextUnmaskModel()
57+
58+
def predict(self, batch):
59+
return self.model.predict(batch)
60+
61+
62+
def test_simple():
63+
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
64+
model = AutoModel.from_pretrained("bert-base-uncased")
65+
inputs = tokenizer("Hello! My name is [MASK]!", return_tensors="pt")
66+
outputs = model(**inputs)
67+
print(outputs)
68+
69+
predicted_index = torch.argmax(outputs[1]).item()
70+
predicted_token = tokenizer.convert_ids_to_tokens(predicted_index)
71+
print(predicted_token)
72+
73+
74+
def test_batch():
75+
batch_text = [
76+
"twinkle twinkle [MASK] star.",
77+
"Happy birthday to [MASK].",
78+
'the answer to life, the [MASK], and everything.'
79+
]
80+
model = TextUnmaskModel()
81+
outputs = model.predict(batch_text)
82+
print(outputs)
83+
84+
85+
if __name__ == "__main__":
86+
test_simple()

demo/app_common/ainlp/test-gpu-async.py

Whitespace-only changes.
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
from aloha.logger import LOG
2+
from aloha.service.api.v0 import APIHandler
3+
4+
5+
class MultipartHandler(APIHandler):
6+
def response(self, params=None, *args, **kwargs):
7+
LOG.debug(params)
8+
return params
9+
10+
11+
default_handlers = [
12+
# internal API: QueryDB Postgres with sql directly
13+
(r"/api_internal/multipart", MultipartHandler),
14+
]

demo/app_common/debug.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ def main():
66
modules_to_load = [
77
"app_common.api.api_common_sys_info",
88
"app_common.api.api_common_query_postgres",
9+
"app_common.api.api_multipart",
910
]
1011

1112
if 'service' not in SETTINGS.config:

src/README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Aloha!
22

33
[![License](https://img.shields.io/github/license/QPod/aloha)](https://github.com/QPod/aloha/blob/main/LICENSE)
4-
[![GitHub Workflow Status](https://img.shields.io/github/workflow/status/QPod/aloha/build)](https://github.com/QPod/aloha/actions)
4+
[![GitHub Workflow Status](https://img.shields.io/github/actions/workflow/status/QPod/aloha-python/pip.yml?branch=main)](https://github.com/QPod/aloha-python/actions)
55
[![Join the Gitter Chat](https://img.shields.io/gitter/room/nwjs/nw.js.svg)](https://gitter.im/QPod/)
66
[![PyPI version](https://img.shields.io/pypi/v/aloha)](https://pypi.python.org/pypi/aloha/)
77
[![PyPI Downloads](https://img.shields.io/pypi/dm/aloha)](https://pepy.tech/badge/aloha/)
@@ -21,6 +21,6 @@ Please generously STAR★ our project or donate to us! [![GitHub Starts](https:
2121

2222
## Getting started
2323

24-
```py
24+
```shell
2525
pip install aloha[all]
2626
```

src/aloha/config/paths.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
__all__ = ('get_resource_dir', 'get_config_dir', 'get_current_module_dir', 'get_project_base_dir', 'path_join')
22

33
import os
4+
import sys
45
import warnings
56

67

@@ -24,7 +25,6 @@ def get_config_dir(*args) -> str:
2425
if dir_config is None or len(dir_config.strip()) == 0:
2526
dir_config = 'config'
2627
dir_config = path_join(dir_resource, dir_config, *args)
27-
# print(' ---> Using config dir:', dir_config)
2828
return dir_config
2929

3030

@@ -48,15 +48,19 @@ def get_config_files() -> list:
4848

4949
files = files_config.split(',')
5050
ret = []
51+
msgs = []
5152
for f in files:
5253
file = get_config_dir(f)
5354
if not os.path.exists(file):
54-
warnings.warn('Expecting config file [%s] but it does not exists!' % file)
55+
msgs.append('Expecting config file [%s] but it does not exists!' % file)
5556
else:
56-
print(' ---> Loading config file [%s]' % file)
57+
print(' ---> Loading config file [%s]' % file, file=sys.stderr)
5758
ret.append(os.path.expandvars(f))
5859
if len(ret) == 0:
59-
warnings.warn('No config files set properly, EMPTY config will be used!')
60+
msgs.append('No config files set properly, EMPTY config will be used!')
61+
62+
if len(msgs) > 0:
63+
warnings.warn('\n'.join(msgs))
6064
return ret
6165

6266

src/aloha/db/mysql.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def __init__(self, db_config, **kwargs):
2424
try:
2525
self.db = create_engine(
2626
'mysql+pymysql://{user}:{password}@{host}:{port}/{dbname}'.format(**self._config),
27-
encoding='utf-8', pool_size=50, pool_recycle=500, pool_pre_ping=True, **kwargs
27+
pool_size=50, pool_recycle=500, pool_pre_ping=True, **kwargs
2828
)
2929
LOG.debug("MySQL connected: {host}:{port}/{dbname}".format(**self._config))
3030
except Exception as e:

0 commit comments

Comments
 (0)