Error: fit() takes 3 positional arguments but 4 were given

Greetings!


Currently, the flow to fit a model looks like this:

1. We have a custom backend used to fit the model:
from omegaml.backends.basemodel import BaseModelBackend
from sdv.tabular import CTGAN
...


class CTGANBackend(BaseModelBackend):
    """
    OmegaML backend to use with PyTorch
    """

    KIND = 'ctgan.pt'

    @classmethod
    def supports(cls, obj, name, **kwargs):
        return isinstance(obj, CTGAN)

    def fit(self, modelname, Xname, **kwargs):
        """
        SDV documentation:
        CTGAN.fit()
        https://sdv.dev/SDV/user_guides/single_table/ctgan.html#quick-usage

        CTGAN._fit():
        https://sdv.dev/SDV/developer_guides/sdv/tabular.html?highlight=discrete_columns#fit-method
        """
        X = self.data_store.get(Xname)
        model = self.model_store.get(modelname)
        model.fit(X)
        result = self.model_store.put(model, modelname)
        return result

...

As you can see, the fit() method here takes 3 arguments (including self).

All the backends were imported into the project without any issues.
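Python counts self among a method's positional parameters, so a bound call to this fit() can supply at most two positional arguments. The sketch below uses a hypothetical StubBackend (a stand-in with the same signature, not the real backend, so it runs without omegaml or SDV) and inspect.signature to list them.

```python
import inspect


class StubBackend:
    """Hypothetical stand-in with the same fit() signature as CTGANBackend."""

    def fit(self, modelname, Xname, **kwargs):
        return modelname, Xname


def positional_params(func):
    # Keep only parameters that can be filled positionally.
    # On the plain function (accessed via the class) self is included;
    # on a bound method inspect.signature drops it.
    kinds = (inspect.Parameter.POSITIONAL_ONLY,
             inspect.Parameter.POSITIONAL_OR_KEYWORD)
    return [name for name, p in inspect.signature(func).parameters.items()
            if p.kind in kinds]


print(positional_params(StubBackend.fit))    # ['self', 'modelname', 'Xname']
print(positional_params(StubBackend().fit))  # ['modelname', 'Xname']
```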


2. We define a pipeline as:

pipeline = CTGAN(epochs=epochs)
om.models.put(pipeline, model_store_name)

or

pipeline = CTGAN()
om.models.put(pipeline, model_store_name)


3. We fit the model using the OmegaML runtime:

runtime = om.runtime.require(label=worker_label, always=True)
fit_result = runtime.model(model_store_name).fit(Xname=dataset_store_name)

Here we pass just 2 arguments: model_store_name (to runtime.model()) and dataset_store_name (as Xname to fit()).
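Even though only Xname is passed here, the traceback further down shows the worker task calling fit(*self.delegate_args, **self.delegate_kwargs). Assuming the runtime packs the call as (modelname, Xname, Yname) with Yname defaulting to None (an assumption about omegaml internals, not something this post confirms), the backend receives three positional arguments plus self. The hypothetical forward_fit below mimics that kind of forwarding against a stub backend and reproduces the same TypeError.

```python
class StubBackend:
    """Hypothetical stand-in with the same fit() signature as CTGANBackend."""

    def fit(self, modelname, Xname, **kwargs):
        return modelname, Xname


def forward_fit(backend, modelname, Xname, Yname=None, **kwargs):
    # Hypothetical forwarder (assumption): every positional slot is passed
    # through, including Yname even when the caller never supplied it.
    delegate_args = (modelname, Xname, Yname)
    return backend.fit(*delegate_args, **kwargs)


error_message = None
try:
    forward_fit(StubBackend(), "model_store_name", "dataset_store_name")
except TypeError as exc:
    # Depending on the Python version the message is prefixed with
    # "fit()" or "StubBackend.fit()".
    error_message = str(exc)
print(error_message)
```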


------------------------------------------------------------------------------------------------------------------------

After calling fit(), I get the error:

fit() takes 3 positional arguments but 4 were given

------------------------------------------------------------------------------------------------------------------------

When I run: celery -A celeryapp events

I see the error status and this traceback:

  Traceback (most recent call last):
   File "/opt/conda/lib/python3.6/site-packages/celery/app/trace.py", line 382, in trace_task
    R = retval = fun(*args, **kwargs)
   File "/opt/conda/lib/python3.6/site-packages/omegaml/celery_util.py", line 97, in __call__
    return super().__call__(*args, **kwargs)
   File "/opt/conda/lib/python3.6/site-packages/celery/app/trace.py", line 641, in __protected_call__
    return self.run(*args, **kwargs)
   File "/opt/conda/lib/python3.6/site-packages/omegaml/tasks.py", line 47, in omega_fit
    result = self.get_delegate(modelname).fit(*self.delegate_args, **self.delegate_kwargs)
  TypeError: fit() takes 3 positional arguments but 4 were given


I cannot see anywhere that fit() could receive more than 3 arguments.

Could you please advise where I should look to find the root cause of the issue?

And how can it be solved?
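A minimal sketch of one possible fix, assuming the fourth argument really is a forwarded Yname: widen the backend's fit() signature to accept it, keeping **kwargs for anything else. FixedStubBackend is hypothetical and only checks that the widened signature accepts both the two-argument and the three-argument call. In CTGANBackend the same change would go on fit(), with Yname simply ignored, since the code above calls model.fit(X) with the data alone. Comparing against the fit() signature of BaseModelBackend in the installed omegaml version is a good way to confirm which arguments are expected.

```python
class FixedStubBackend:
    """Hypothetical stand-in showing a widened fit() signature."""

    def fit(self, modelname, Xname, Yname=None, **kwargs):
        # Yname is accepted so positional forwarding no longer fails;
        # an unsupervised model such as CTGAN can just ignore it.
        return modelname, Xname, Yname


backend = FixedStubBackend()
two_arg = backend.fit("model_store_name", "dataset_store_name")
three_arg = backend.fit(*("model_store_name", "dataset_store_name", None))
print(two_arg == three_arg)  # True
```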


P.S. I run the project locally, using separate Docker containers for Django and the worker. The backends are installed without any issues.
