Error: fit() takes 3 positional arguments but 4 were given
Greetings!
Currently, the flow to fit a model looks like this:
1. We have a custom backend that is used to fit the model:
```python
from omegaml.backends.basemodel import BaseModelBackend
from sdv.tabular import CTGAN
...

class CTGANBackend(BaseModelBackend):
    """
    OmegaML backend to use with Pytorch
    """
    KIND = 'ctgan.pt'

    @classmethod
    def supports(cls, obj, name, **kwargs):
        return isinstance(obj, CTGAN)

    def fit(self, modelname, Xname, **kwargs):
        """
        SDV documentation:
        CTGAN.fit() https://sdv.dev/SDV/user_guides/single_table/ctgan.html#quick-usage
        CTGAN._fit(): https://sdv.dev/SDV/developer_guides/sdv/tabular.html?highlight=discrete_columns#fit-method
        """
        X = self.data_store.get(Xname)
        model = self.model_store.get(modelname)
        model.fit(X)
        result = self.model_store.put(model, modelname)
        return result
    ...
```
As you can see, fit() here takes 3 positional arguments (including self), plus **kwargs.
All the backends are imported into the project without any issues.
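The supports() classmethod is what lets a backend claim an object. As a simplified, hypothetical illustration (plain stand-in classes, not omegaml's actual backend registry, and not SDV's CTGAN), picking the first backend whose supports() returns True can be sketched as:

```python
# Hypothetical stand-ins: FakeCTGAN replaces sdv's CTGAN, and the list
# passed to pick_backend() replaces omegaml's real backend registry.
class FakeCTGAN:
    pass


class SketchCTGANBackend:
    KIND = 'ctgan.pt'

    @classmethod
    def supports(cls, obj, name, **kwargs):
        # Claim only CTGAN-like objects, mirroring the isinstance() check above
        return isinstance(obj, FakeCTGAN)


class SketchFallbackBackend:
    KIND = 'fallback'

    @classmethod
    def supports(cls, obj, name, **kwargs):
        # Catch-all: accepts any object
        return True


def pick_backend(obj, name, backends):
    """Return the first backend class whose supports() accepts obj."""
    for backend in backends:
        if backend.supports(obj, name):
            return backend
    raise ValueError(f'no backend supports {name!r}')


backends = [SketchCTGANBackend, SketchFallbackBackend]
print(pick_backend(FakeCTGAN(), 'ctgan-model', backends).KIND)  # ctgan.pt
print(pick_backend(object(), 'other', backends).KIND)           # fallback
```

Order matters in this sketch: if the catch-all backend were listed first, it would shadow the CTGAN-specific one.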
2. We define a pipeline as:
```python
pipeline = CTGAN(epochs=epochs)
om.models.put(pipeline, model_store_name)
```
or
```python
pipeline = CTGAN()
om.models.put(pipeline, model_store_name)
```
3. We fit the model using the OmegaML runtime:
```python
runtime = om.runtime.require(label=worker_label, always=True)
fit_result = runtime.model(model_store_name).fit(Xname=dataset_store_name)
```
Here we pass just 2 names: model_store_name (to runtime.model()) and dataset_store_name (as Xname to fit()).
------------------------------------------------------------------------------------------------------------------------
After calling fit(), I get the error:
fit() takes 3 positional arguments but 4 were given
------------------------------------------------------------------------------------------------------------------------
When I run `celery -A celeryapp events`, I see the error status and the following traceback:
```
Traceback (most recent call last):
  File "/opt/conda/lib/python3.6/site-packages/celery/app/trace.py", line 382, in trace_task
    R = retval = fun(*args, **kwargs)
  File "/opt/conda/lib/python3.6/site-packages/omegaml/celery_util.py", line 97, in __call__
    return super().__call__(*args, **kwargs)
  File "/opt/conda/lib/python3.6/site-packages/celery/app/trace.py", line 641, in __protected_call__
    return self.run(*args, **kwargs)
  File "/opt/conda/lib/python3.6/site-packages/omegaml/tasks.py", line 47, in omega_fit
    result = self.get_delegate(modelname).fit(*self.delegate_args, **self.delegate_kwargs)
TypeError: fit() takes 3 positional arguments but 4 were given
```
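The last frame shows the worker unpacking self.delegate_args positionally into the backend's fit(). Assuming, for illustration only, that the task forwards three positional values (modelname, Xname and a Yname that defaults to None), a minimal stand-in reproduces the same TypeError. The class name and argument values below are hypothetical:

```python
class SketchBackend:
    # Same parameter list as CTGANBackend.fit() in the report:
    # self, modelname, Xname -> 3 positional parameters
    def fit(self, modelname, Xname, **kwargs):
        return modelname, Xname


# Assumption: the worker builds delegate_args as (modelname, Xname, Yname),
# even though the client only passed Xname=...
delegate_args = ('ctgan-model', 'training-data', None)
delegate_kwargs = {}

try:
    SketchBackend().fit(*delegate_args, **delegate_kwargs)
except TypeError as exc:
    # Python 3.6 prints: fit() takes 3 positional arguments but 4 were given
    # (3.10+ prefixes the class name: SketchBackend.fit() takes ...)
    print(exc)
```

The bound instance counts as the first positional argument, which is why three forwarded values are reported as "4 were given".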
I cannot see anywhere that fit() would be called with more than 3 arguments.
Could you please advise where I should look to find the root cause of the issue?
And how can it be solved?
P.S. I run the project locally using separate Docker containers for Django and the worker. The backends install without any issues.
Comments
Thanks for reporting this. To reproduce the problem, could you list which versions of the following you are using?
Thank you.
I am using:
For Python:
Greetings!
I have found the issue - it was on my side.
In the OmegaML docs I found these lines:
As we can see, fit takes 4 parameters (including self).
But inside my backend I have:
That is only 3 parameters.
So, to fix this, I just added *args to my backend's fit() method, like this:
Now it works fine!
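A minimal sketch of this kind of fix, assuming the worker calls fit(modelname, Xname, Yname) positionally (consistent with the 4-parameter signature quoted from the docs). Declaring Yname=None matches that shape explicitly, and *args absorbs any further positional values. DictStore, FakeModel and SketchCTGANBackend are hypothetical stand-ins so the example runs without omegaml or SDV:

```python
class DictStore:
    """Hypothetical stand-in for omegaml's data/model stores."""

    def __init__(self):
        self._items = {}

    def get(self, name):
        return self._items[name]

    def put(self, obj, name):
        self._items[name] = obj
        return name


class FakeModel:
    """Hypothetical stand-in for a CTGAN instance."""

    def __init__(self):
        self.fitted_on = None

    def fit(self, X):
        self.fitted_on = X


class SketchCTGANBackend:
    def __init__(self, data_store, model_store):
        self.data_store = data_store
        self.model_store = model_store

    # Yname=None matches the assumed 4-parameter signature (incl. self);
    # *args additionally absorbs any extra positional values.
    def fit(self, modelname, Xname, Yname=None, *args, **kwargs):
        X = self.data_store.get(Xname)
        model = self.model_store.get(modelname)
        model.fit(X)  # the model is fitted on X alone, so Yname is ignored
        return self.model_store.put(model, modelname)


data, models = DictStore(), DictStore()
data.put([[1, 2], [3, 4]], 'training-data')
models.put(FakeModel(), 'ctgan-model')
backend = SketchCTGANBackend(data, models)
# Three positional values, as assumed for the worker, no longer raise TypeError
print(backend.fit('ctgan-model', 'training-data', None))  # ctgan-model
```

Naming Yname explicitly documents the expected call shape; *args on its own would also work but hides which values the worker actually sends.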
Thanks for your time!
Thank you for the detailed feedback and your solution.
The underlying issue is a mismatch between the fit() signature on the client side and the one used by the Celery task, which should of course be avoided. Tracking it here: https://github.com/omegaml/omegaml/issues/174
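One way to surface such a mismatch before a task runs is to bind the arguments the worker will pass against the backend's fit() signature with inspect.signature().bind(), which raises TypeError without calling the method. The helper and backend classes below are hypothetical and not part of omegaml, and the worker call shape is an assumption:

```python
import inspect


def check_call_signature(func, *args, **kwargs):
    """Return None if func accepts (args, kwargs), else the TypeError message.

    Hypothetical helper: bind() checks the arguments against the
    signature without actually invoking func.
    """
    try:
        inspect.signature(func).bind(*args, **kwargs)
    except TypeError as exc:
        return str(exc)
    return None


class OldBackend:
    def fit(self, modelname, Xname, **kwargs):
        pass


class FixedBackend:
    def fit(self, modelname, Xname, Yname=None, **kwargs):
        pass


# Assumed worker call shape: (modelname, Xname, Yname) passed positionally
worker_args = ('ctgan-model', 'training-data', None)
print(check_call_signature(OldBackend().fit, *worker_args))    # too many positional arguments
print(check_call_signature(FixedBackend().fit, *worker_args))  # None
```

Because the check uses the bound method, self is already excluded from the signature, so only the forwarded values are compared.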