Comments (4)
It is ignored because you define `load` in your inference.py and not `load_fn` as documented here. That's why the toolkit uses the `load` function of the `HuggingFaceHandlerService`.
Additionally, I think the `model` in your `predict` might not be used, since currently the `HuggingFaceHandlerService` uses `self.model` in its `predict` function.
from sagemaker-huggingface-inference-toolkit.
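The fallback described above (a user-defined `load` being silently ignored in favor of the default handler) can be sketched as a name-based lookup. This is a guess at the mechanism, not the toolkit's actual code; `resolve_handler`, `DefaultHandlerService`, and `UserModule` are hypothetical names for illustration:

```python
class DefaultHandlerService:
    """Stands in for the toolkit's built-in HuggingFaceHandlerService behavior."""
    def load(self, model_dir):
        return f"default pipeline from {model_dir}"

def resolve_handler(user_module, service, name="load_fn"):
    # If the user's inference.py defines the documented name, use it;
    # otherwise fall back to the service's built-in implementation.
    fn = getattr(user_module, name, None)
    if callable(fn):
        return fn
    return service.load

# Simulated user inference.py that defines `load` instead of `load_fn`:
class UserModule:
    @staticmethod
    def load(model_dir):
        return f"custom model from {model_dir}"

service = DefaultHandlerService()
handler = resolve_handler(UserModule, service, name="load_fn")
# `load` != `load_fn`, so the default implementation wins:
print(handler("/opt/ml/model"))
```

Under this sketch, renaming the user function to the documented name would make `resolve_handler` pick it up instead of the default.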
OK, I'm lost :) How should we name the functions in inference.py? (also asked here #6)
- load_fn, preprocess_fn, predict, postprocess (what's currently in the doc)
- load_fn, preprocess_fn, predict_fn, postprocess_fn

I understand that it will soon switch to model_fn, input_fn, predict_fn, output_fn, but I'm curious about right now.
Here is the example we use for the tests. Yes, the documentation has an issue: it should be `_fn` everywhere.
And here is where they get loaded.
This is the final setup that worked; as mentioned in the doc, the contract now is `model_fn`, `input_fn`, `predict_fn`, `output_fn`:
```python
import logging
import os

import tensorflow as tf
from transformers import TFAutoModelForQuestionAnswering, AutoTokenizer

logging.basicConfig(level=logging.INFO)


def model_fn(model_dir):
    """this function reads the model from disk"""
    logging.info('model_fn dir view:')
    logging.info(os.listdir())
    # load model
    transformer = TFAutoModelForQuestionAnswering.from_pretrained(model_dir)
    # load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    return transformer, tokenizer


def predict_fn(processed_data, model):
    """this function runs inference"""
    transformer, tokenizer = model
    question, text = processed_data['question'], processed_data['context']
    logging.info('processed_data received: {}'.format(processed_data))
    # infer
    input_dict = tokenizer(question, text, return_tensors='tf')
    outputs = transformer(input_dict)
    # post processing
    start_logits = outputs.start_logits
    end_logits = outputs.end_logits
    all_tokens = tokenizer.convert_ids_to_tokens(input_dict["input_ids"].numpy()[0])
    answer = ' '.join(all_tokens[tf.math.argmax(start_logits, 1)[0] : tf.math.argmax(end_logits, 1)[0] + 1])
    return answer
```
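The contract above also names `input_fn` and `output_fn`, which the snippet leaves to the defaults. A minimal sketch of the two, assuming JSON in and out and the same `question`/`context` keys that `predict_fn` expects (the exact serialization is a choice, not something the toolkit mandates):

```python
import json


def input_fn(request_body, content_type):
    """Deserialize the request body; assumes JSON with 'question' and 'context' keys."""
    if content_type == "application/json":
        return json.loads(request_body)
    raise ValueError(f"Unsupported content type: {content_type}")


def output_fn(prediction, accept):
    """Serialize the answer string back to the client as JSON."""
    return json.dumps({"answer": prediction})
```

With these in place, the handler pipeline is `input_fn` → `predict_fn` → `output_fn`, with `model_fn` run once at startup.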