How to use Huggingface GenerationMixin (or its beam search) with my own model?


Huggingface's use of a mixin keeps teasing me that this should be possible, but I can't find any clear documentation on exactly what the requirements are, or if the dependencies are just too much to make it worth it. The central module is literally thousands and thousands of lines, and I felt from studying it yesterday that I've learnt more about how to write beam search than I have about GenerationMixin. :-)

From reading the source, I think the dependencies are self.config, prepare_inputs_for_generation() and _update_model_kwargs_for_generation(); also, implicitly, forward(). But I'm not sure that is everything, nor what each should look like. And I think it may expect forward() to return data in a specific format.

To make the discussion specific, and generally useful, how could Huggingface's beam search be used with minGPT, which has a forward() function that returns logits, loss? (It actually has its own generate() function that does the equivalent of Huggingface's sample() and greedy_search(), but no beam search support.) Or nanoGPT if you prefer - they are identical in this area.
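To make the question concrete, here is a rough sketch of what such a wrapper might look like, based on my reading of the source: inherit from GenerationMixin, provide a self.config, implement prepare_inputs_for_generation(), and have forward() return a ModelOutput with a .logits field. The exact requirements vary between transformers versions, and DummyGPT below is a stand-in for the real minGPT model, so treat this as illustrative rather than a known-good recipe:

```python
# Hedged sketch: wrap a minGPT-style model (forward() returns logits, loss)
# so that GenerationMixin can drive it.  transformers API details vary by
# version; DummyGPT is a toy stand-in for the real minGPT module.
import torch
from transformers import GenerationConfig, GenerationMixin, PretrainedConfig
from transformers.modeling_outputs import CausalLMOutput


class DummyGPT(torch.nn.Module):
    """Stand-in for minGPT, just so the sketch is self-contained."""

    def __init__(self, vocab_size=10, dim=8):
        super().__init__()
        self.emb = torch.nn.Embedding(vocab_size, dim)
        self.head = torch.nn.Linear(dim, vocab_size)

    def forward(self, idx, targets=None):
        # minGPT-style signature: returns (logits, loss)
        return self.head(self.emb(idx)), None


class MinGPTForGeneration(torch.nn.Module, GenerationMixin):
    main_input_name = "input_ids"

    def __init__(self, mingpt_model, vocab_size):
        super().__init__()
        self.model = mingpt_model
        # GenerationMixin reads generation defaults (eos/pad ids, lengths,
        # encoder/decoder flag, ...) from these two attributes.
        self.config = PretrainedConfig(vocab_size=vocab_size,
                                       is_encoder_decoder=False)
        self.generation_config = GenerationConfig()

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        # Called before every decoding step.  minGPT keeps no KV cache, so
        # we simply feed the whole sequence back in each time.
        return {"input_ids": input_ids}

    def forward(self, input_ids, **kwargs):
        logits, _loss = self.model(input_ids)
        # The decoding loops expect a ModelOutput (or dict) with .logits of
        # shape (batch, seq_len, vocab_size).
        return CausalLMOutput(logits=logits)

    @property
    def device(self):
        return next(self.model.parameters()).device

    def can_generate(self):
        return True
```

If the wrapper satisfies whatever your transformers version checks for, something like `model.generate(input_ids, num_beams=4, max_new_tokens=20)` should then dispatch to beam search - but I would expect to have to tweak the details against the installed version.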

In the comments I said, "It seems everyone's generate/beam search implementation is tied in closely with their transformer implementation," and I still can't really see why everyone reinvents this wheel, and why there is no standalone open source beam search implementation with a clearly defined interface. I'm going to throw a bounty at this question to see if it helps.
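For what it's worth, a standalone beam search really does only need a tiny interface. Here is a dependency-free sketch where the "model" is just a callable mapping a token prefix to a {next_token: log_prob} dict; the names are my own, not from any library:

```python
# Minimal standalone beam search.  The only contract with the model is
# step_fn(prefix_tuple) -> {next_token: log_prob}; `eos` marks a finished
# hypothesis.  Returns the single best (score, sequence) pair.
import math


def beam_search(step_fn, start, eos, beam_width=3, max_len=10):
    # Each hypothesis is (cumulative_log_prob, token_list).
    beams = [(0.0, [start])]
    finished = []
    for _ in range(max_len):
        candidates = []
        for score, seq in beams:
            for tok, logp in step_fn(tuple(seq)).items():
                candidates.append((score + logp, seq + [tok]))
        # Keep the best beam_width hypotheses, moving finished ones aside.
        candidates.sort(key=lambda c: c[0], reverse=True)
        beams = []
        for score, seq in candidates[:beam_width]:
            (finished if seq[-1] == eos else beams).append((score, seq))
        if not beams:
            break
    finished.extend(beams)  # unfinished hypotheses still count
    return max(finished, key=lambda c: c[0])


# Toy "model": prefers token "a", then ends after three generated tokens.
def toy_step(prefix):
    if len(prefix) >= 3:
        return {"<eos>": math.log(0.9), "b": math.log(0.1)}
    return {"a": math.log(0.6), "b": math.log(0.4)}
```

With `beam_search(toy_step, start="<s>", eos="<eos>", beam_width=2)` the top hypothesis comes out as `["<s>", "a", "a", "<eos>"]`. A real adapter would just wrap a model's forward pass (plus log-softmax and a top-k cut) behind step_fn, which is exactly the kind of clearly defined interface the question is asking about.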


There are 2 answers

Viacheslav Ivannikov

If you want to use Hugging Face code, what you're looking for is generate() from the GenerationMixin class; see here.

So your options are to either adapt your model to inherit from GenerationMixin, or copy the code over. Either way, it depends on your model being Hugging Face-friendly, so just plugging in a random one without adjusting the code won't work.

If you don't necessarily want to use Hugging Face code, there are a number of convenient implementations on GitHub that are easier to adapt; for example, here.

Darkoob12

I have the same problem. I asked several chatbots, but they generated nonsensical answers; apparently they cannot generalize beyond the web.

If you check the code, you will reach the can_generate method, which contains the following:

    # Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation.
    # Alternatively, the model can also have a custom `generate` function.
    if "GenerationMixin" in str(cls.prepare_inputs_for_generation) and "GenerationMixin" in str(cls.generate):
        return False
    return True

So, you have two options: override either prepare_inputs_for_generation or generate.

The first option is suitable if you want to use one of the search strategies implemented in the GenerationMixin class but have, say, a multi-modal model with different input requirements.

But if you need a different search strategy (say, you want to generate XML and must select the next token in a way that ensures XML validity), you have to write your own generate method. This is not a good design; there should have been a search or sample method that we could override.

There is also another abstraction, LogitsProcessor, which is used to change logit values before sampling. But the documentation is minimal and there are no examples of custom usage.
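As far as I can tell, a custom processor is just a callable taking (input_ids, scores) and returning modified scores, so a hedged example might look like this (AllowedTokensProcessor is my own name, and the constraint here is a deliberately crude whitelist rather than real XML validation):

```python
# Hedged example of a custom LogitsProcessor: mask out every token that is
# not in an allowed set by adding -inf to its logit before sampling.
import torch
from transformers import LogitsProcessor


class AllowedTokensProcessor(LogitsProcessor):
    def __init__(self, allowed_token_ids):
        self.allowed = list(allowed_token_ids)

    def __call__(self, input_ids, scores):
        # scores has shape (batch, vocab_size); forbidden tokens get -inf
        # so they can never be sampled or chosen by beam search.
        mask = torch.full_like(scores, float("-inf"))
        mask[:, self.allowed] = 0.0
        return scores + mask
```

If I read the generate() signature correctly, you would pass this in via something like `logits_processor=LogitsProcessorList([AllowedTokensProcessor(ids)])`; a real grammar constraint would compute the allowed set from input_ids at each step instead of using a fixed whitelist.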